• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for pfor and for_loop."""
16# pylint: disable=g-direct-tensorflow-import
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import functools
23import time
24
25from absl.testing import parameterized
26import numpy as np
27
28from tensorflow.core.example import example_pb2
29from tensorflow.core.example import feature_pb2
30from tensorflow.python.client import session
31from tensorflow.python.eager import backprop
32from tensorflow.python.eager import context
33from tensorflow.python.eager import def_function
34from tensorflow.python.framework import composite_tensor
35from tensorflow.python.framework import config
36from tensorflow.python.framework import constant_op
37from tensorflow.python.framework import dtypes
38from tensorflow.python.framework import indexed_slices
39from tensorflow.python.framework import ops
40from tensorflow.python.framework import sparse_tensor
41from tensorflow.python.framework import tensor_shape
42from tensorflow.python.framework import tensor_spec
43from tensorflow.python.framework import test_util
44from tensorflow.python.framework import type_spec
45from tensorflow.python.ops import array_ops
46from tensorflow.python.ops import bitwise_ops
47from tensorflow.python.ops import cond_v2
48from tensorflow.python.ops import control_flow_ops
49from tensorflow.python.ops import control_flow_v2_toggles
50from tensorflow.python.ops import data_flow_ops
51from tensorflow.python.ops import functional_ops
52from tensorflow.python.ops import gen_dataset_ops
53from tensorflow.python.ops import gen_list_ops
54from tensorflow.python.ops import gen_nn_ops
55from tensorflow.python.ops import gradient_checker_v2
56from tensorflow.python.ops import gradients as gradient_ops
57from tensorflow.python.ops import image_ops
58from tensorflow.python.ops import list_ops
59from tensorflow.python.ops import logging_ops
60from tensorflow.python.ops import map_fn
61from tensorflow.python.ops import math_ops
62from tensorflow.python.ops import nn
63from tensorflow.python.ops import parsing_ops
64from tensorflow.python.ops import random_ops
65from tensorflow.python.ops import resource_variable_ops
66from tensorflow.python.ops import rnn
67from tensorflow.python.ops import rnn_cell
68from tensorflow.python.ops import stateless_random_ops
69from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
70from tensorflow.python.ops import tensor_array_ops
71from tensorflow.python.ops import variables
72from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
73from tensorflow.python.ops.parallel_for.test_util import PForTestCase
74from tensorflow.python.ops.ragged import ragged_tensor
75from tensorflow.python.ops.signal import fft_ops
76from tensorflow.python.platform import test
77from tensorflow.python.util import nest
78
79
80@test_util.run_all_in_graph_and_eager_modes
81@test_util.with_control_flow_v2
82class PForTest(PForTestCase):
83
84  def test_op_conversion_fallback_to_while_loop(self):
85    # Note that we used top_k op for this test. If a converter gets defined for
86    # it, we will need to find another op for which a converter has not been
87    # defined.
88    x = random_ops.random_uniform([3, 2, 4])
89
90    def loop_fn(i):
91      x_i = array_ops.gather(x, i)
92      return nn.top_k(x_i)
93
94    with self.assertRaisesRegex(ValueError, "No pfor vectorization"):
95      self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
96    self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)
97
98  def test_parallel_iterations(self):
99    for parallel_iterations in [2, 3, 8, 10]:
100      x = random_ops.random_uniform([8, 3])
101
102      # pylint: disable=cell-var-from-loop
103      def loop_fn(i):
104        return array_ops.gather(x, i)
105
106      # pylint: enable=cell-var-from-loop
107
108      self._test_loop_fn(loop_fn, 8, parallel_iterations=parallel_iterations)
109      self._test_loop_fn(
110          loop_fn,
111          4 * constant_op.constant(2),
112          parallel_iterations=parallel_iterations)
113
114  def test_parallel_iterations_preserves_static_shape(self):
115    for parallel_iterations in [2, 3, 8, 10]:
116      x = pfor_control_flow_ops.pfor(
117          lambda _: random_ops.random_uniform([2, 3]),
118          8,
119          parallel_iterations=parallel_iterations)
120      self.assertAllEqual(x.shape, [8, 2, 3])
121
122  def test_parallel_iterations_zero(self):
123    with self.assertRaisesRegex(ValueError, "positive integer"):
124      pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=0)
125    with self.assertRaisesRegex(TypeError, "positive integer"):
126      pfor_control_flow_ops.for_loop(
127          lambda i: 1, dtypes.int32, 8, parallel_iterations=0)
128
129  def test_parallel_iterations_one(self):
130    with self.assertRaisesRegex(ValueError, "Use for_loop instead"):
131      pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=1)
132
133  def test_vectorized_map(self):
134
135    def compute(x):
136      return math_ops.reduce_mean(x, axis=0, keepdims=True)
137
138    result = pfor_control_flow_ops.vectorized_map(compute,
139                                                  array_ops.ones((10, 5, 3)))
140    self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))
141
142  def test_vectorized_map_with_dynamic_shape(self):
143
144    def compute(x):
145      return math_ops.reduce_mean(x, axis=0, keepdims=True)
146
147    x = array_ops.placeholder_with_default(
148        array_ops.ones((10, 5, 3)), shape=None)
149    result = pfor_control_flow_ops.vectorized_map(compute, x)
150    self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))
151
152  def test_vectorized_map_broadcasts_unit_dimensions(self):
153    convert_with_static_shape = ops.convert_to_tensor
154    convert_with_dynamic_shape = (
155        lambda x: array_ops.placeholder_with_default(x, shape=None))
156
157    for convert in (convert_with_static_shape, convert_with_dynamic_shape):
158      a = convert([3.1])
159      b = convert([-2., 6., 9.])
160
161      # One elem with leading unit dimension.
162      a_plus_1 = pfor_control_flow_ops.vectorized_map(lambda a: a + 1, a)
163      self.assertAllEqual(*self.evaluate((a_plus_1, a + 1)))
164
165      # Two elems, both with leading unit dimension.
166      a_plus_a = pfor_control_flow_ops.vectorized_map(sum, (a, a))
167      self.assertAllEqual(*self.evaluate((a_plus_a, a + a)))
168
169      # Elem w/ unit dimension broadcast against elem with batch dim.
170      a_plus_b = pfor_control_flow_ops.vectorized_map(sum, (a, b))
171      self.assertAllEqual(*self.evaluate((a_plus_b, a + b)))
172
173  def test_vectorized_map_example_1(self):
174
175    def outer_product(a):
176      return math_ops.tensordot(a, a, 0)
177
178    batch_size = 100
179    a = array_ops.ones((batch_size, 32, 32))
180    c = pfor_control_flow_ops.vectorized_map(outer_product, a)
181    self.assertAllEqual((batch_size, 32, 32, 32, 32), c.shape)
182
183  def test_disable_tf_function(self):
184    def_function.run_functions_eagerly(True)
185    # vectorized_map should ignore disabling tf.functions
186    self.assertTrue(def_function.functions_run_eagerly())
187    self.assertAllEqual([0, 1, 4, 9],
188                        pfor_control_flow_ops.vectorized_map(
189                            lambda x: x * x, math_ops.range(4)))
190    self.assertTrue(def_function.functions_run_eagerly())
191    def_function.run_functions_eagerly(False)
192
193
194@test_util.run_all_in_graph_and_eager_modes
195class IndexedSlicesTest(PForTestCase):
196
197  def test_indexed_slices(self):
198
199    def loop_fn(i):
200      return indexed_slices.IndexedSlices(
201          indices=i, values=array_ops.reshape(i, [1]), dense_shape=[3, 1])
202
203    self._test_loop_fn(loop_fn, 2)
204
205  def test_indexed_slices_components(self):
206
207    def loop_fn(i):
208      slices = indexed_slices.IndexedSlices(
209          indices=i, values=array_ops.reshape(i, [1]), dense_shape=[3, 1])
210      # Note that returning the components inside the slice avoids
211      # densification, which may be more efficient.
212      return slices.values, slices.indices
213
214    self._test_loop_fn(loop_fn, 2)
215
216
217@test_util.run_all_in_graph_and_eager_modes
218class ReductionTest(PForTestCase):
219
220  def test_reduce(self):
221
222    def reduce_fn(p, q):
223      return math_ops.reduce_mean(p + q, axis=0)
224
225    x = random_ops.random_uniform([4, 3, 2])
226    y = random_ops.random_uniform([4, 3, 2])
227
228    def loop_fn(i, pfor_config):
229      x_i = array_ops.gather(x, i)
230      y_i = array_ops.gather(y, i)
231      reduced = pfor_config.reduce(reduce_fn, x_i, y_i)
232      return reduced + x_i
233
234    output = pfor_control_flow_ops.pfor(loop_fn, 4)
235    ans = reduce_fn(x, y) + x
236    output_val, ans_val = self.evaluate([output, ans])
237    self.assertAllClose(ans_val, output_val)
238
239  def test_reduce_concat(self):
240    x = random_ops.random_uniform([8, 3])
241
242    def loop_fn(i, pfor_config):
243      x_i = array_ops.gather(x, i)
244      vectorized_value = pfor_config.reduce_concat(x_i)
245      mean_value = math_ops.reduce_mean(vectorized_value, axis=0)
246      return x_i - mean_value
247
248    output = pfor_control_flow_ops.pfor(loop_fn, 8)
249    ans = x - math_ops.reduce_mean(x, axis=0)
250    output_val, ans_val = self.evaluate([output, ans])
251    self.assertAllClose(ans_val, output_val)
252
253  def test_reduce_mean(self):
254    x = random_ops.random_uniform([8, 3])
255
256    def loop_fn(i, pfor_config):
257      x_i = array_ops.gather(x, i)
258      return x_i - pfor_config.reduce_mean(x_i)
259
260    output = pfor_control_flow_ops.pfor(loop_fn, 8)
261    ans = x - math_ops.reduce_mean(x, axis=0)
262    output_val, ans_val = self.evaluate([output, ans])
263    self.assertAllClose(ans_val, output_val)
264
265  def test_reduce_sum(self):
266    x = random_ops.random_uniform([8, 3])
267
268    def loop_fn(i, pfor_config):
269      x_i = array_ops.gather(x, i)
270      return x_i - pfor_config.reduce_sum(x_i)
271
272    output = pfor_control_flow_ops.pfor(loop_fn, 8)
273    ans = x - math_ops.reduce_sum(x, axis=0)
274    output_val, ans_val = self.evaluate([output, ans])
275    self.assertAllClose(ans_val, output_val)
276
277  def test_reduce_class(self):
278    x = random_ops.random_uniform([8, 3])
279
280    class LoopFn(object):
281
282      def __init__(self):
283        pass
284
285      def __call__(self, i, pfor_config):
286        x_i = array_ops.gather(x, i)
287        return x_i - pfor_config.reduce_mean(x_i)
288
289    output = pfor_control_flow_ops.pfor(LoopFn(), 8)
290    ans = x - math_ops.reduce_mean(x, axis=0)
291    output_val, ans_val = self.evaluate([output, ans])
292    self.assertAllClose(ans_val, output_val)
293
294  def test_reduce_functools_partial(self):
295    x = random_ops.random_uniform([8, 3])
296
297    def fn(i, pfor_config, dummy=None):
298      del dummy
299      x_i = array_ops.gather(x, i)
300      return x_i - pfor_config.reduce_mean(x_i)
301
302    loop_fn = functools.partial(fn, dummy=1)
303    output = pfor_control_flow_ops.pfor(loop_fn, 8)
304    ans = x - math_ops.reduce_mean(x, axis=0)
305    output_val, ans_val = self.evaluate([output, ans])
306    self.assertAllClose(ans_val, output_val)
307
308  def test_parallel_iterations(self):
309    x = random_ops.random_uniform([8, 3])
310
311    def loop_fn(i, pfor_config):
312      x_i = array_ops.gather(x, i)
313      return pfor_config.reduce_sum(x_i)
314
315    with self.assertRaisesRegex(ValueError,
316                                "parallel_iterations currently unsupported"):
317      pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2)
318
319  def test_var_loop_len(self):
320    if context.executing_eagerly():
321      self.skipTest("Variable length not possible under eager execution.")
322
323    x = random_ops.random_uniform([8, 3])
324
325    def loop_fn(i, pfor_config):
326      return pfor_config.reduce_sum(array_ops.gather(x, i))
327
328    num_iters = array_ops.placeholder(dtypes.int32)
329    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
330    with self.cached_session() as sess:
331      sess.run(pfor, feed_dict={num_iters: 8})
332
333
334@test_util.run_all_in_graph_and_eager_modes
335class BitwiseTest(PForTestCase):
336
337  def test_unary_cwise(self):
338    for op in [bitwise_ops.invert]:
339      x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)
340
341      # pylint: disable=cell-var-from-loop
342      def loop_fn(i):
343        x1 = array_ops.gather(x, i)
344        return op(x1)
345
346      # pylint: enable=cell-var-from-loop
347
348      self._test_loop_fn(loop_fn, 3)
349
350  def test_binary_cwise(self):
351    binary_ops = [
352        bitwise_ops.bitwise_and,
353        bitwise_ops.bitwise_or,
354        bitwise_ops.bitwise_xor,
355        bitwise_ops.left_shift,
356        bitwise_ops.right_shift,
357    ]
358    for op in binary_ops:
359      x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)
360      y = random_ops.random_uniform([3, 5], maxval=10, dtype=dtypes.int32)
361
362      output_dtypes = []
363
364      # pylint: disable=cell-var-from-loop
365      def loop_fn(i):
366        x1 = array_ops.gather(x, i)
367        y1 = array_ops.gather(y, i)
368        outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
369        del output_dtypes[:]
370        output_dtypes.extend(t.dtype for t in outputs)
371        return outputs
372
373      # pylint: enable=cell-var-from-loop
374      self._test_loop_fn(loop_fn, 3)
375
376
377@test_util.run_all_in_graph_and_eager_modes
378class ImageTest(PForTestCase):
379
380  def test_adjust_contrast(self):
381    images = random_ops.random_uniform([3, 2, 4, 4, 3])
382
383    def loop_fn(i):
384      image = array_ops.gather(images, i)
385      return image_ops.adjust_contrast(image, 2.0)
386
387    self._test_loop_fn(loop_fn, 3)
388
389  def test_adjust_hue(self):
390    images = random_ops.random_uniform([3, 2, 4, 4, 3])
391
392    def loop_fn(i):
393      image = array_ops.gather(images, i)
394      return image_ops.adjust_hue(image, .25)
395
396    self._test_loop_fn(loop_fn, 3)
397
398  def test_adjust_saturation(self):
399    images = random_ops.random_uniform([3, 2, 4, 4, 3])
400
401    def loop_fn(i):
402      image = array_ops.gather(images, i)
403      return image_ops.adjust_saturation(image, 0.1)
404
405    self._test_loop_fn(loop_fn, 3)
406
407
408@test_util.run_all_in_graph_and_eager_modes
409class NNTest(PForTestCase):
410
411  def test_conv2d(self):
412    x = random_ops.random_uniform([3, 2, 12, 12, 3])
413    filt = random_ops.random_uniform([3, 3, 3, 7])
414
415    def loop_fn(i):
416      x1 = array_ops.gather(x, i)
417      return nn.conv2d(
418          x1, filt, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC")
419
420    self._test_loop_fn(loop_fn, 3)
421
422  def test_conv2d_backprop_input(self):
423    x_shape = [2, 12, 12, 3]
424    filt = random_ops.random_uniform([3, 3, 3, 7])
425    grad = random_ops.random_uniform([3, 2, 5, 5, 7])
426
427    def loop_fn(i):
428      grad1 = array_ops.gather(grad, i)
429      return nn.conv2d_backprop_input(
430          x_shape,
431          filt,
432          grad1,
433          strides=[1, 2, 2, 1],
434          padding="VALID",
435          data_format="NHWC")
436
437    self._test_loop_fn(loop_fn, 3)
438
439  def test_conv2d_backprop_filter(self):
440    x = random_ops.random_uniform([3, 2, 12, 12, 3])
441    x_0 = array_ops.gather(x, 0)
442    filter_sizes = [3, 3, 3, 7]
443    grad = random_ops.random_uniform([3, 2, 5, 5, 7])
444
445    def loop_fn(i):
446      x_i = array_ops.gather(x, i)
447      grad_i = array_ops.gather(grad, i)
448      return [
449          nn.conv2d_backprop_filter(
450              inp,
451              filter_sizes,
452              grad_i,
453              strides=[1, 2, 2, 1],
454              padding="VALID",
455              data_format="NHWC") for inp in [x_i, x_0]
456      ]
457
458    self._test_loop_fn(loop_fn, 3)
459
460  def test_avg_pool(self):
461    with backprop.GradientTape(persistent=True) as g:
462      x = random_ops.random_uniform([3, 2, 12, 12, 3])
463      g.watch(x)
464      ksize = [1, 3, 3, 1]
465
466    def loop_fn(i):
467      with g:
468        x1 = array_ops.gather(x, i)
469        output = nn.avg_pool(
470            x1,
471            ksize,
472            strides=[1, 2, 2, 1],
473            padding="VALID",
474            data_format="NHWC")
475        loss = nn.l2_loss(output)
476      return output, g.gradient(loss, x1)
477
478    self._test_loop_fn(loop_fn, 3)
479
480  def test_avg_pool3d(self):
481    with backprop.GradientTape(persistent=True) as g:
482      x = random_ops.random_uniform([5, 3, 7, 6, 6, 5])
483      g.watch(x)
484      ksize = [1, 2, 2, 2, 1]
485      strides = [1, 2, 2, 2, 1]
486
487    def loop_fn(i):
488      with g:
489        x1 = array_ops.gather(x, i)
490        output = nn.avg_pool3d(
491            x1, ksize, strides=strides, padding="VALID", data_format="NDHWC")
492        loss = nn.l2_loss(output)
493      return output, g.gradient(loss, x1)
494
495    self._test_loop_fn(loop_fn, 3)
496
497  def test_max_pool(self):
498    with backprop.GradientTape(persistent=True) as g:
499      x = random_ops.random_uniform([3, 2, 12, 12, 3])
500      g.watch(x)
501      ksize = [1, 3, 3, 1]
502      strides = [1, 2, 2, 1]
503
504    def loop_fn(i):
505      with g:
506        x1 = array_ops.gather(x, i)
507        output = nn.max_pool(
508            x1, ksize, strides=strides, padding="VALID", data_format="NHWC")
509        loss = nn.l2_loss(output)
510        ones = array_ops.ones_like(output)
511        g.watch(ones)
512        grad = g.gradient(loss, x1, output_gradients=ones)
513      grad_grad = g.gradient(grad, ones)
514      return output, grad, grad_grad
515
516    self._test_loop_fn(loop_fn, 3)
517
518  def test_max_pool_v2(self):
519    with backprop.GradientTape(persistent=True) as g:
520      x = random_ops.random_uniform([3, 2, 12, 12, 3])
521      g.watch(x)
522      ksize = [1, 3, 3, 1]
523      strides = [1, 2, 2, 1]
524
525    def loop_fn(i):
526      with g:
527        x1 = array_ops.gather(x, i)
528        output = gen_nn_ops.max_pool_v2(
529            x1, ksize, strides=strides, padding="VALID", data_format="NHWC")
530        loss = nn.l2_loss(output)
531        ones = array_ops.ones_like(output)
532        g.watch(ones)
533        grad = g.gradient(loss, x1, output_gradients=ones)
534      grad_grad = g.gradient(grad, ones)
535      return output, grad, grad_grad
536
537    self._test_loop_fn(loop_fn, 3)
538
539  def test_max_pool3d(self):
540    with backprop.GradientTape(persistent=True) as g:
541      x = random_ops.random_uniform([3, 3, 2, 12, 12, 3])
542      g.watch(x)
543      ksize = [1, 1, 3, 3, 1]
544      strides = [1, 1, 2, 2, 1]
545
546    def loop_fn(i):
547      with g:
548        x1 = array_ops.gather(x, i)
549        output = nn.max_pool3d(
550            x1, ksize, strides=strides, padding="VALID", data_format="NDHWC")
551        loss = nn.l2_loss(output)
552        ones = array_ops.ones_like(output)
553        g.watch(ones)
554        grad = g.gradient(loss, x1, output_gradients=ones)
555      grad_grad = g.gradient(grad, ones)
556      return output, grad, grad_grad
557
558    self._test_loop_fn(loop_fn, 3)
559
560  def test_fused_batch_norm(self):
561    data_formats = ["NHWC"]
562    if test.is_gpu_available():
563      data_formats.append("NCHW")
564    for is_training in (True, False):
565      for data_format in data_formats:
566        with backprop.GradientTape(persistent=True) as g:
567          if data_format == "NCHW":
568            x = random_ops.random_uniform([3, 1, 2, 5, 5])
569          else:
570            x = random_ops.random_uniform([3, 1, 5, 5, 2])
571          g.watch(x)
572          scale = random_ops.random_uniform([2])
573          g.watch(scale)
574          offset = random_ops.random_uniform([2])
575          g.watch(offset)
576          mean = None if is_training else random_ops.random_uniform([2])
577          variance = None if is_training else random_ops.random_uniform([2])
578
579        # pylint: disable=cell-var-from-loop
580        def loop_fn(i):
581          with g:
582            x1 = array_ops.gather(x, i)
583            outputs = nn.fused_batch_norm(
584                x1,
585                scale,
586                offset,
587                mean=mean,
588                variance=variance,
589                epsilon=0.01,
590                data_format=data_format,
591                is_training=is_training)
592            outputs = list(outputs)
593            # We only test the first value of outputs when is_training is
594            # False. It looks like CPU and GPU have different outputs for
595            # batch_mean and batch_variance for this case.
596            if not is_training:
597              outputs[1] = constant_op.constant(0.)
598              outputs[2] = constant_op.constant(0.)
599            loss = nn.l2_loss(outputs[0])
600          if is_training:
601            gradients = g.gradient(loss, [x1, scale, offset])
602          else:
603            gradients = [constant_op.constant(0.)] * 3
604          return outputs + gradients
605
606        # pylint: enable=cell-var-from-loop
607
608        self._test_loop_fn(loop_fn, 3)
609
610  def test_log_softmax(self):
611    logits = random_ops.random_uniform([3, 2, 4])
612
613    def loop_fn(i):
614      logits_i = array_ops.gather(logits, i)
615      return (nn.log_softmax(logits_i), nn.log_softmax(logits_i, axis=0),
616              nn.log_softmax(logits_i, axis=-1))
617
618    self._test_loop_fn(loop_fn, 3)
619
620  def test_softmax(self):
621    logits = random_ops.random_uniform([3, 2, 4])
622
623    def loop_fn(i):
624      logits_i = array_ops.gather(logits, i)
625      return (nn.softmax(logits_i), nn.softmax(logits_i, axis=0),
626              nn.softmax(logits_i, axis=-1))
627
628    self._test_loop_fn(loop_fn, 3)
629
630  def test_softmax_cross_entropy_with_logits(self):
631    with backprop.GradientTape(persistent=True) as g:
632      logits = random_ops.random_uniform([3, 2, 4])
633      g.watch(logits)
634      labels = random_ops.random_uniform([3, 2, 4])
635      labels /= math_ops.reduce_sum(labels, axis=[2], keepdims=True)
636
637    def loop_fn(i):
638      with g:
639        logits_i = array_ops.gather(logits, i)
640        labels_i = array_ops.gather(labels, i)
641        loss = nn.softmax_cross_entropy_with_logits(
642            labels=labels_i, logits=logits_i)
643        total_loss = math_ops.reduce_sum(loss)
644      return loss, g.gradient(total_loss, logits_i)
645
646    self._test_loop_fn(loop_fn, 3)
647
648  def test_sparse_softmax_cross_entropy_with_logits(self):
649    logits = random_ops.random_uniform([3, 2, 4])
650    labels = random_ops.random_uniform(
651        shape=[3, 2], maxval=4, dtype=dtypes.int32)
652
653    def loop_fn(i):
654      logits_i = array_ops.gather(logits, i)
655      labels_i = array_ops.gather(labels, i)
656      loss = nn.sparse_softmax_cross_entropy_with_logits(
657          labels=labels_i, logits=logits_i)
658      return loss
659
660    self._test_loop_fn(loop_fn, 3)
661
662
663class RandomTest(PForTestCase):
664
665  # The random values generated in the two implementations are not guaranteed to
666  # match. So we only check the returned shapes.
667  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
668    outputs = self._run_targets(targets1, targets2)
669    n = len(outputs) // 2
670    for i in range(n):
671      self.assertAllEqual(outputs[i].shape, outputs[i + n].shape)
672
673  def test_random_uniform(self):
674
675    def loop_fn(_):
676      return random_ops.random_uniform([3])
677
678    self._test_loop_fn(loop_fn, 5)
679
680  def test_random_uniform_int(self):
681
682    def loop_fn(_):
683      return random_ops.random_uniform([3], maxval=1, dtype=dtypes.int32)
684
685    self._test_loop_fn(loop_fn, 5)
686
687  def test_random_standard_normal(self):
688
689    def loop_fn(_):
690      return random_ops.random_normal([3])
691
692    self._test_loop_fn(loop_fn, 5)
693
694  def test_truncated_normal(self):
695
696    def loop_fn(_):
697      return random_ops.truncated_normal([3])
698
699    self._test_loop_fn(loop_fn, 5)
700
701  def test_random_gamma_invariant_alpha(self):
702
703    def loop_fn(_):
704      return random_ops.random_gamma([3], alpha=[0.5])
705
706    self._test_loop_fn(loop_fn, 5)
707
708  def test_random_gamma_varying_alpha(self):
709    alphas = math_ops.exp(random_ops.random_normal([5, 3, 2]))
710
711    def loop_fn(i):
712      alphas_i = array_ops.gather(alphas, i)
713      # Test both scalar and non-scalar params and shapes.
714      return (random_ops.random_gamma(alpha=alphas_i[0, 0], shape=[]),
715              random_ops.random_gamma(alpha=alphas_i, shape=[]),
716              random_ops.random_gamma(alpha=alphas_i[0, 0], shape=[3]),
717              random_ops.random_gamma(alpha=alphas_i, shape=[3]))
718
719    self._test_loop_fn(loop_fn, 5)
720
721  def test_random_poisson_v2_invariant_rate(self):
722
723    def loop_fn(_):
724      return random_ops.random_poisson(lam=[1.3], shape=[3])
725
726    self._test_loop_fn(loop_fn, 5)
727
728  def test_random_poisson_v2_varying_rate(self):
729    rates = math_ops.exp(random_ops.random_normal([5, 3, 2]))
730
731    def loop_fn(i):
732      rates_i = array_ops.gather(rates, i)
733      # Test both scalar and non-scalar params and shapes.
734      return (random_ops.random_poisson(lam=rates_i[0, 0], shape=[]),
735              random_ops.random_poisson(lam=rates_i, shape=[]),
736              random_ops.random_poisson(lam=rates_i[0, 0], shape=[3]),
737              random_ops.random_poisson(lam=rates_i, shape=[3]))
738
739    self._test_loop_fn(loop_fn, 5)
740
741  def test_random_multinomial_invariant_logits(self):
742
743    def loop_fn(_):
744      return random_ops.categorical(logits=[[1., -1.]], num_samples=3)
745
746    self._test_loop_fn(loop_fn, 5)
747
748  def test_random_multinomial_varying_logits(self):
749    logits = random_ops.random_normal([5, 3, 2])
750
751    def loop_fn(i):
752      logits_i = array_ops.gather(logits, i)
753      return random_ops.categorical(logits_i, num_samples=3)
754
755    self._test_loop_fn(loop_fn, 5)
756
757
758class StatelessRandomTest(PForTestCase):
759
760  # This test currently only tests that the vectorized and non-vectorized
761  # outputs have same shapes. This is needed since under XLA compilation,
762  # stateless random numbers can generate different random numbers.
763  # TODO(agarwal): switch to checking for actual values matching once
764  # b/149402339 is resolved.
765  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
766    outputs = self._run_targets(targets1, targets2)
767    n = len(outputs) // 2
768    for i in range(n):
769      self.assertAllEqual(outputs[i].shape, outputs[i + n].shape)
770
771  # TODO(agarwal): add tests for other random functions
772  def test_multinomial(self):
773    seeds = [[1, 2], [3, 4]]
774    logits = random_ops.random_uniform([2, 3, 4])
775
776    def loop_fn(i):
777      logits_0 = array_ops.gather(logits, 0)
778      logits_i = array_ops.gather(logits, i)
779      seeds_0 = array_ops.gather(seeds, 0)
780      seeds_i = array_ops.gather(seeds, i)
781      return (stateless_random_ops.stateless_categorical(
782          logits=logits_i, num_samples=3, seed=seeds_i),
783              stateless_random_ops.stateless_categorical(
784                  logits=logits_i, num_samples=3, seed=seeds_0),
785              stateless_random_ops.stateless_categorical(
786                  logits=logits_0, num_samples=3, seed=seeds_i),
787              stateless_random_ops.stateless_categorical(
788                  logits=logits_0, num_samples=3, seed=seeds_0))
789
790    self._test_loop_fn(loop_fn, 2)
791
792
793class LoggingTest(PForTestCase):
794
795  @test_util.run_v1_only("b/122612051")
796  def test_print(self):
797    x = random_ops.random_uniform([3, 5])
798
799    def loop_fn(i):
800      x1 = array_ops.gather(x, i)
801      return logging_ops.Print(
802          x1, [x1, "x1", array_ops.shape(x1)], summarize=10)
803
804    self._test_loop_fn(loop_fn, 3)
805
806  def test_assert(self):
807
808    def loop_fn(i):
809      return control_flow_ops.Assert(i < 10, [i, [10], [i + 1]])
810
811    # TODO(agarwal): make this work with for_loop.
812    with session.Session() as sess:
813      sess.run(pfor_control_flow_ops.pfor(loop_fn, 3))
814      sess.run(pfor_control_flow_ops.pfor(
815          lambda i, pfor_config: loop_fn(i), 3))
816
817
818class TensorArrayTest(PForTestCase):
819
820  def setUp(self):
821    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
822    control_flow_v2_toggles.disable_control_flow_v2()
823    super(TensorArrayTest, self).setUp()
824
825  def tearDown(self):
826    if self._enabled:
827      control_flow_v2_toggles.enable_control_flow_v2()
828    super(TensorArrayTest, self).tearDown()
829
830  @test_util.run_v1_only("b/122612051")
831  def test_create_outside_and_read(self):
832
833    ta = tensor_array_ops.TensorArray(
834        dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1)
835
836    def loop_fn(i):
837      return ta.read(i), ta.read(0)
838
839    self._test_loop_fn(loop_fn, 2)
840
841  @test_util.run_v1_only("b/122612051")
842  def test_create_outside_and_gather(self):
843
844    ta = tensor_array_ops.TensorArray(
845        dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1)
846
847    def loop_fn(i):
848      return ta.gather([i]), ta.gather([0, 1])
849
850    self._test_loop_fn(loop_fn, 2)
851
852  @test_util.run_v1_only("b/122612051")
853  def test_create_outside_and_write_and_scatter(self):
854
855    t = tensor_array_ops.TensorArray(dtypes.int32, 10, clear_after_read=False)
856    handle = t.handle
857
858    def loop_fn(i):
859      ta = t.write(i + 2, 2 * i).write(i, 5)
860      ta = ta.scatter([4 + i], [4]).scatter([6 + i, 8 + i], [6 + i, 8 + i])
861      return ta.flow
862
863    t1 = pfor_control_flow_ops.pfor(loop_fn, iters=2)
864    out1 = tensor_array_ops.TensorArray(
865        dtypes.int32, handle=handle, flow=t1[-1]).stack()
866    output1 = self._run_targets(out1)
867
868    t2 = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, iters=2)
869    out2 = tensor_array_ops.TensorArray(
870        dtypes.int32, handle=handle, flow=t2[-1]).stack()
871    output2 = self._run_targets(out2)
872    self.assertAllClose(output2, output1)
873
874  @test_util.run_v1_only("b/122612051")
875  def test_create_inside_and_write(self):
876
877    def loop_fn(i):
878      # TODO(agarwal): switching the order of writes to ta1 does not work.
879      ta1 = tensor_array_ops.TensorArray(dtypes.int32, 2).write(0,
880                                                                i).write(1, 1)
881      ta2 = tensor_array_ops.TensorArray(dtypes.int32, 1).write(0, 1)
882      return ta1.stack(), ta2.stack()
883
884    self._test_loop_fn(loop_fn, 3)
885
886  @test_util.run_v1_only("b/122612051")
887  def test_create_inside_and_scatter(self):
888
889    def loop_fn(i):
890      # TODO(agarwal): switching the order of scatter to ta1 does not work.
891      ta1 = tensor_array_ops.TensorArray(dtypes.int32,
892                                         2).scatter([0],
893                                                    [[i, 2]]).scatter([1],
894                                                                      [[1, 2]])
895      ta2 = tensor_array_ops.TensorArray(dtypes.int32,
896                                         2).scatter([0], [3]).scatter([1], [4])
897      return ta1.stack(), ta2.stack()
898
899    self._test_loop_fn(loop_fn, 3)
900
901  @test_util.run_v1_only("b/122612051")
902  def test_create_inside_and_read(self):
903
904    def loop_fn(i):
905      ta1 = tensor_array_ops.TensorArray(
906          dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1)
907      ta2 = tensor_array_ops.TensorArray(
908          dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2)
909      # TODO(agarwal): ta1.read(i) currently is not supported.
910      return ta1.read(0), ta2.read(0), ta2.read(i)
911
912    self._test_loop_fn(loop_fn, 2)
913
914  @test_util.run_v1_only("b/122612051")
915  def test_create_inside_and_gather(self):
916
917    def loop_fn(i):
918      ta1 = tensor_array_ops.TensorArray(
919          dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1)
920      ta2 = tensor_array_ops.TensorArray(
921          dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2)
922      # TODO(agarwal): ta1.read(i) currently is not supported.
923      return ta1.gather([0, 1]), ta2.gather([0, 1]), ta2.gather([i])
924
925    self._test_loop_fn(loop_fn, 2)
926
927  @test_util.run_v1_only("b/122612051")
928  def test_grad(self):
929    x = random_ops.random_uniform([3, 2])
930    ta = tensor_array_ops.TensorArray(
931        dtypes.float32, 3, clear_after_read=False).unstack(x)
932    y = math_ops.square(ta.stack())
933
934    def loop_fn(i):
935      y_i = array_ops.gather(y, i)
936      grad = gradient_ops.gradients(y_i, x)[0]
937      return array_ops.gather(grad, i)
938
939    t1 = pfor_control_flow_ops.pfor(loop_fn, iters=3)
940    # y = x * x. Hence dy/dx = 2 * x.
941    actual_grad = 2.0 * x
942    with session.Session() as sess:
943      actual_grad, computed_grad = sess.run([t1, actual_grad])
944      self.assertAllClose(actual_grad, computed_grad)
945
946
947@test_util.run_all_in_graph_and_eager_modes
948class TensorListTest(PForTestCase):
949
950  def test_create_outside_and_write(self):
951    handle1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
952    handle2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
953
954    def loop_fn(i):
955      h1 = list_ops.tensor_list_set_item(handle1, 0, i)
956      h1 = list_ops.tensor_list_set_item(h1, 1, 1)
957      h2 = list_ops.tensor_list_set_item(handle2, 0, 1)
958      return (list_ops.tensor_list_stack(h1, dtypes.int32),
959              list_ops.tensor_list_stack(h2, dtypes.int32))
960
961    self._test_loop_fn(loop_fn, 3)
962
963  def test_create_inside_and_write(self):
964
965    def loop_fn(i):
966      h1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
967      h1 = list_ops.tensor_list_set_item(h1, 0, i)
968      h1 = list_ops.tensor_list_set_item(h1, 1, 1)
969      h2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
970      h2 = list_ops.tensor_list_set_item(h2, 0, 1)
971      return (list_ops.tensor_list_stack(h1, dtypes.int32),
972              list_ops.tensor_list_stack(h2, dtypes.int32))
973
974    self._test_loop_fn(loop_fn, 3)
975
976  def test_create_outside_and_read(self):
977    handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
978    handle = list_ops.tensor_list_set_item(handle, 0, 0)
979    handle = list_ops.tensor_list_set_item(handle, 1, 1)
980
981    def loop_fn(i):
982      return (list_ops.tensor_list_get_item(handle, i, dtypes.int32),
983              list_ops.tensor_list_get_item(handle, 0, dtypes.int32),
984              list_ops.tensor_list_length(handle),
985              list_ops.tensor_list_element_shape(handle, dtypes.int32),
986              list_ops.tensor_list_element_shape(handle, dtypes.int64))
987
988    self._test_loop_fn(loop_fn, 2)
989
990  @test_util.disable_tfrt("b/180206304")
991  def test_create_inside_and_read(self):
992
993    def loop_fn(i):
994      handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
995      handle = list_ops.tensor_list_set_item(handle, 0, i)
996      handle = list_ops.tensor_list_set_item(handle, 1, 1)
997      return (list_ops.tensor_list_get_item(handle, 0, dtypes.int32),
998              list_ops.tensor_list_get_item(handle, i, dtypes.int32),
999              list_ops.tensor_list_length(handle),
1000              list_ops.tensor_list_element_shape(handle, dtypes.int32),
1001              list_ops.tensor_list_element_shape(handle, dtypes.int64))
1002
1003    self._test_loop_fn(loop_fn, 2)
1004
1005  def test_create_outside_and_push_back(self):
1006    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
1007
1008    def loop_fn(i):
1009      handle = list_ops.tensor_list_push_back(h, [i, 2])
1010      handle = list_ops.tensor_list_push_back(handle, [1, 2])
1011      handle = list_ops.tensor_list_push_back(handle, [1, 2])
1012      return list_ops.tensor_list_stack(handle, dtypes.int32)
1013
1014    self._test_loop_fn(loop_fn, 3)
1015
1016  def test_create_inside_and_push_back(self):
1017
1018    def loop_fn(i):
1019      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
1020      handle = list_ops.tensor_list_push_back(handle, [i, 2])
1021      handle = list_ops.tensor_list_push_back(handle, [1, 2])
1022      return list_ops.tensor_list_stack(handle, dtypes.int32)
1023
1024    self._test_loop_fn(loop_fn, 3)
1025
1026  def test_pop_back_no_shape(self):
1027
1028    def loop_fn(i):
1029      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
1030      handle = list_ops.tensor_list_push_back(handle, [1, 2])
1031      handle = list_ops.tensor_list_push_back(handle, [i, 2])
1032      handle, tensor = list_ops.tensor_list_pop_back(handle, dtypes.int32)
1033      return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)
1034
1035    self._test_loop_fn(loop_fn, 3)
1036
1037  def test_pop_back_no_shape_capture(self):
1038    h = list_ops.tensor_list_reserve([2], 1, dtypes.int32)
1039    h = list_ops.tensor_list_push_back(h, [1, 2])
1040
1041    def loop_fn(i):
1042      handle, tensor = list_ops.tensor_list_pop_back(h, dtypes.int32)
1043      handle = list_ops.tensor_list_push_back(handle, [1, i])
1044      return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)
1045
1046    self._test_loop_fn(loop_fn, 3)
1047
1048  def test_pop_back_with_shape(self):
1049
1050    @def_function.function
1051    def loop_fn(i):
1052      with backprop.GradientTape() as tape:
1053        handle = list_ops.tensor_list_reserve(None, 1, dtypes.float32)
1054        x = math_ops.cast(i, dtypes.float32)[None]
1055        tape.watch(x)
1056        handle = list_ops.tensor_list_push_back(handle, x)
1057        stacked = list_ops.tensor_list_stack(handle, dtypes.float32)
1058      list_grad = tape.gradient(stacked, x, x)
1059      self.assertEqual("TensorListPopBack", list_grad.op.type)
1060      return list_grad, stacked, list_grad.op.inputs[1]
1061
1062    self._test_loop_fn(loop_fn, 3)
1063
1064  def test_create_outside_and_scatter(self):
1065    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
1066
1067    def loop_fn(i):
1068      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=h)
1069      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
1070      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
1071      return list_ops.tensor_list_stack(handle, dtypes.int32)
1072
1073    self._test_loop_fn(loop_fn, 3)
1074
1075  def test_create_inside_and_scatter(self):
1076
1077    def loop_fn(i):
1078      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
1079      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
1080      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
1081      return list_ops.tensor_list_stack(handle, dtypes.int32)
1082
1083    self._test_loop_fn(loop_fn, 3)
1084
1085  def test_create_outside_and_gather(self):
1086    handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
1087    handle = list_ops.tensor_list_scatter([[2, 3]], [0], input_handle=handle)
1088    handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
1089
1090    def loop_fn(i):
1091      return (list_ops.tensor_list_gather(handle, [0, 1], dtypes.int32),
1092              list_ops.tensor_list_gather(handle, [i], dtypes.int32))
1093
1094    self._test_loop_fn(loop_fn, 2)
1095
1096  def test_create_inside_and_gather(self):
1097
1098    def loop_fn(i):
1099      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
1100      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
1101      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
1102      return (list_ops.tensor_list_gather(handle, [0, 1], dtypes.int32),
1103              list_ops.tensor_list_gather(handle, [i], dtypes.int32))
1104
1105    self._test_loop_fn(loop_fn, 2)
1106
1107  def test_create_inside_and_concat(self):
1108
1109    def loop_fn(i):
1110      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
1111      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
1112      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
1113      return gen_list_ops.tensor_list_concat_v2(
1114          handle,
1115          element_dtype=dtypes.int32,
1116          element_shape=[2],
1117          leading_dims=[])
1118
1119    output = pfor_control_flow_ops.pfor(loop_fn, 2)
1120    self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], output[0])
1121    self.assertAllClose([[2, 2], [2, 2]], output[1])
1122
1123  def test_create_outside_and_concat(self):
1124    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
1125
1126    def loop_fn(i):
1127      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=h)
1128      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
1129      return gen_list_ops.tensor_list_concat_v2(
1130          handle,
1131          element_dtype=dtypes.int32,
1132          element_shape=[2],
1133          leading_dims=[])
1134
1135    output = pfor_control_flow_ops.pfor(loop_fn, 2)
1136    self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], output[0])
1137    self.assertAllClose([[2, 2], [2, 2]], output[1])
1138
1139  def test_tensor_list_from_tensor(self):
1140    t = random_ops.random_uniform([2, 3, 4])
1141
1142    def loop_fn(i):
1143      handle = list_ops.tensor_list_from_tensor(array_ops.gather(t, i), [4])
1144      return list_ops.tensor_list_stack(handle, t.dtype)
1145
1146    self._test_loop_fn(loop_fn, 2)
1147
1148  def test_tensor_list_reserve_while_loop(self):
1149    # Here a loop invariant TensorList is captured by a while_loop, which then
1150    # performs loop dependent operations on it, resulting in a loop variant
1151    # output. This forces stacking of the variant handle captured by the
1152    # while_loop.
1153    # We handle this particular case by forcing vectorization of
1154    # TensorListReserve operation.
1155    v2_enabled = control_flow_v2_toggles.control_flow_v2_enabled()
1156    control_flow_v2_toggles.enable_control_flow_v2()
1157
1158    def loop_fn(i):
1159      handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
1160      _, out_handle = control_flow_ops.while_loop(
1161          lambda j, _: j < 2, lambda j, h:
1162          (j + 1, list_ops.tensor_list_set_item(h, j, i)), (0, handle))
1163      return list_ops.tensor_list_stack(out_handle, dtypes.int32)
1164
1165    self._test_loop_fn(loop_fn, 2)
1166    if not v2_enabled:
1167      control_flow_v2_toggles.disable_control_flow_v2()
1168
1169  def test_tensor_list_addn_already_stacked(self):
1170
1171    def loop_fn(i):
1172      l1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
1173      l1 = list_ops.tensor_list_set_item(l1, 0, i)
1174      l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
1175      l2 = list_ops.tensor_list_set_item(l2, 1, i)
1176      return list_ops.tensor_list_stack(math_ops.add_n([l1, l2]), dtypes.int32)
1177
1178    self._test_loop_fn(loop_fn, 2)
1179
1180  def test_tensor_list_addn_stacking_required(self):
1181    l1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
1182    l1 = list_ops.tensor_list_set_item(l1, 1, 1)
1183
1184    def loop_fn(i):
1185      l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
1186      l2 = list_ops.tensor_list_set_item(l2, 1, i)
1187      return list_ops.tensor_list_stack(
1188          math_ops.add_n([l1, l2]), dtypes.int32)
1189
1190    self._test_loop_fn(loop_fn, 2)
1191
1192
1193class OptionalTest(PForTestCase):
1194
1195  def test_optional_from_value(self):
1196
1197    def loop_fn(i):
1198      o = gen_dataset_ops.optional_from_value(
1199          [i, i + 1, constant_op.constant(3)])
1200      gen_dataset_ops.optional_none()
1201      return gen_dataset_ops.optional_get_value(
1202          o, [dtypes.int32, dtypes.int32, dtypes.int32],
1203          [[], [], []])
1204
1205    self._test_loop_fn(loop_fn, 2)
1206
1207
1208class StackTest(PForTestCase):
1209
1210  @test_util.run_v1_only("b/122612051")
1211  def test_stack_inside_loop_invariant(self):
1212
1213    def loop_fn(_):
1214      s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
1215      op1 = data_flow_ops.stack_push_v2(s, 1)
1216      with ops.control_dependencies([op1]):
1217        op2 = data_flow_ops.stack_push_v2(s, 2)
1218      with ops.control_dependencies([op2]):
1219        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
1220      with ops.control_dependencies([e2]):
1221        e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
1222      return e1, e2
1223
1224    self._test_loop_fn(loop_fn, 2)
1225
1226  @test_util.run_v1_only("b/122612051")
1227  def test_stack_inside_push_loop_dependent(self):
1228
1229    def loop_fn(i):
1230      s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
1231      op1 = data_flow_ops.stack_push_v2(s, i)
1232      with ops.control_dependencies([op1]):
1233        op2 = data_flow_ops.stack_push_v2(s, 2)
1234      with ops.control_dependencies([op2]):
1235        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
1236      with ops.control_dependencies([e2]):
1237        e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
1238      return e1, e2
1239
1240    self._test_loop_fn(loop_fn, 2)
1241
1242  @test_util.run_v1_only("b/122612051")
1243  def test_stack_outside_pop(self):
1244    s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
1245    op = data_flow_ops.stack_push_v2(s, 5)
1246    with ops.control_dependencies([op]):
1247      op = data_flow_ops.stack_push_v2(s, 6)
1248    with ops.control_dependencies([op]):
1249      op = data_flow_ops.stack_push_v2(s, 7)
1250
1251    def loop_fn(_):
1252      e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
1253      with ops.control_dependencies([e1]):
1254        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
1255      return e1, e2
1256
1257    with ops.control_dependencies([op]):
1258      e1, e2 = pfor_control_flow_ops.pfor(loop_fn, iters=2)
1259    with ops.control_dependencies([e1, e2]):
1260      e3 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
1261    v1, v2, v3 = self._run_targets([e1, e2, e3], run_init=False)
1262    self.assertAllEqual([7, 7], v1)
1263    self.assertAllEqual([6, 6], v2)
1264    self.assertAllEqual(5, v3)
1265
1266  @test_util.run_v1_only("b/122612051")
1267  def test_stack_outside_push(self):
1268    s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
1269
1270    def loop_fn(_):
1271      return data_flow_ops.stack_push_v2(s, 7)
1272
1273    with self.assertRaisesRegex(ValueError, "StackPushV2 not allowed.*"):
1274      pfor_control_flow_ops.pfor(loop_fn, iters=2)
1275
1276
1277# TODO(agarwal): test nested while_loops. This currently requires converting a
1278# tf.cond.
1279class WhileV1Test(PForTestCase):
1280
1281  def setUp(self):
1282    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
1283    control_flow_v2_toggles.disable_control_flow_v2()
1284    super(WhileV1Test, self).setUp()
1285
1286  def tearDown(self):
1287    if self._enabled:
1288      control_flow_v2_toggles.enable_control_flow_v2()
1289    super(WhileV1Test, self).tearDown()
1290
1291  def test_while_outside_loop(self):
1292
1293    x = control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
1294
1295    def loop_fn(i):
1296      return x + i
1297
1298    self._test_loop_fn(loop_fn, 3)
1299
1300  @test_util.run_v1_only("b/122612051")
1301  def test_invariant_while(self):
1302
1303    def loop_fn(_):
1304      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
1305
1306    self._test_loop_fn(loop_fn, 3)
1307
1308  @test_util.run_v1_only("b/122612051")
1309  def test_invariant_while_with_control_dependency(self):
1310
1311    def loop_fn(i):
1312      with ops.control_dependencies([i]):
1313        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
1314                                           [0])
1315
1316    self._test_loop_fn(loop_fn, 3)
1317
1318  @test_util.run_v1_only("b/122612051")
1319  def test_while_with_stateful_ops(self):
1320
1321    def loop_fn(_):
1322      return control_flow_ops.while_loop(
1323          lambda j, x: j < 4, lambda j, x:
1324          (j + 1, x + random_ops.random_uniform([])), [0, 0.])[0]
1325
1326    self._test_loop_fn(loop_fn, 3)
1327
1328  @test_util.run_v1_only("b/122612051")
1329  def test_while_unstacked_condition(self):
1330
1331    def loop_fn(i):
1332      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
1333                                         (j + 1, x + i), [0, 0])
1334
1335    self._test_loop_fn(loop_fn, 3)
1336
1337  @test_util.run_v1_only("b/122612051")
1338  def test_while(self):
1339    x = random_ops.random_uniform([3, 5])
1340    lengths = constant_op.constant([4, 0, 2])
1341
1342    def loop_fn(i):
1343      x_i = array_ops.gather(x, i)
1344      lengths_i = array_ops.gather(lengths, i)
1345
1346      _, total = control_flow_ops.while_loop(
1347          lambda j, _: j < lengths_i, lambda j, t:
1348          (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
1349      return total
1350
1351    self._test_loop_fn(loop_fn, 3)
1352
1353  @test_util.run_v1_only("b/122612051")
1354  def test_while_jacobian(self):
1355    x = random_ops.random_uniform([1, 3])
1356    y = random_ops.random_uniform([3, 3])
1357
1358    # out = x @ y @ y @ y @ y, where @ is matmul operator.
1359    _, out = control_flow_ops.while_loop(
1360        lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
1361        [0, x])
1362
1363    def loop_fn(i):
1364      out_i = array_ops.gather(out, i, axis=1)
1365      return array_ops.reshape(gradient_ops.gradients(out_i, x)[0], [-1])
1366
1367    out = pfor_control_flow_ops.pfor(loop_fn, iters=3)
1368
1369    # The above code does not work with tf.while_loop instead of pfor. So we
1370    # manually compute the expected output here.
1371    # Note that gradient of output w.r.t is (y @ y @ y @ y)^T.
1372    expected_output = y
1373    for _ in range(3):
1374      expected_output = math_ops.matmul(expected_output, y)
1375    expected_output = array_ops.transpose(expected_output, [1, 0])
1376
1377    with session.Session() as sess:
1378      out, expected = sess.run([out, expected_output])
1379      self.assertAllClose(expected, out)
1380
1381  @test_util.run_v1_only("b/122612051")
1382  def test_tensor_array_as_loop_variable(self):
1383
1384    def loop_fn(i):
1385
1386      def body(j, ta):
1387        ta = ta.write(j, i + j * j)
1388        return j + 1, ta
1389
1390      _, ta = control_flow_ops.while_loop(
1391          lambda j, _: j < 4, body,
1392          (0, tensor_array_ops.TensorArray(dtypes.int32, size=4)))
1393      return ta.stack()
1394
1395    self._test_loop_fn(loop_fn, 3)
1396
1397  @test_util.run_v1_only("b/122612051")
1398  def test_read_tensor_array_partitioned_indices(self):
1399    # Note that tensor array values are pfor loop dependent, and the while loop
1400    # termination condition is also dependent on pfor iteration.
1401    def loop_fn(i):
1402      ta = tensor_array_ops.TensorArray(dtypes.int32, size=6)
1403      ta = ta.unstack(i + list(range(5)))
1404
1405      def body(j, s):
1406        return j + 1, s + ta.read(j)
1407
1408      _, s = control_flow_ops.while_loop(lambda j, _: j < i, body, (0, 0))
1409      return s
1410
1411    self._test_loop_fn(loop_fn, 3)
1412
1413  @test_util.run_v1_only("b/122612051")
1414  def test_external_while_loop_grad(self):
1415    # Here we test that external while_loops that are extended from inside pfor
1416    # (due to gradient calls) are not actually converted. If the below was
1417    # converted all pfor iterations would write to the same tensor array
1418    # indices.
1419    x = constant_op.constant(1.)
1420
1421    def body(j, ta):
1422      ta = ta.write(j, x)
1423      return j + 1, ta
1424
1425    _, ta = control_flow_ops.while_loop(
1426        lambda j, _: j < 4, body,
1427        (0, tensor_array_ops.TensorArray(dtypes.float32, size=4)))
1428    out = ta.stack()
1429
1430    def loop_fn(i):
1431      out_i = array_ops.gather(out, i)
1432      return gradient_ops.gradients(out_i, x)[0]
1433
1434    with session.Session() as sess:
1435      # out is [x, x, x]. Hence the gradients should be [1, 1, 1].
1436      self.assertAllEqual([1, 1, 1],
1437                          sess.run(pfor_control_flow_ops.pfor(loop_fn, 3)))
1438
1439  @test_util.run_v1_only("b/122612051")
1440  def test_tensor_array_grad(self):
1441    inp = constant_op.constant(np.random.rand(3, 4, 2), dtype=dtypes.float32)
1442    ta = tensor_array_ops.TensorArray(dtypes.float32, size=3)
1443    ta = ta.unstack(inp)
1444
1445    def loop_fn(i):
1446
1447      def body(j, x):
1448        value = ta.gather([j])
1449        value = array_ops.gather(array_ops.reshape(value, [4, 2]), i)
1450        return j + 1, x + value
1451
1452      _, out = control_flow_ops.while_loop(lambda j, _: j < 3, body,
1453                                           (0, array_ops.zeros([2])))
1454      out = math_ops.reduce_prod(out)
1455      return out, gradient_ops.gradients(out, inp)[0]
1456
1457    pfor_out, pfor_out_grad = pfor_control_flow_ops.pfor(loop_fn, 4)
1458    # Note that tf.while_loop does not work in the setup above. So we manually
1459    # construct the equivalent computation of the above loops here.
1460    real_out = math_ops.reduce_sum(inp, axis=[0])
1461    real_out = math_ops.reduce_prod(real_out, axis=[1])
1462    # Note that gradients of real_out will accumulate the gradients across the
1463    # output value. Hence we do the same aggregation on pfor_out_grad.
1464    real_out_grad = gradient_ops.gradients(real_out, inp)[0]
1465    sum_pfor_out_grad = math_ops.reduce_sum(pfor_out_grad, axis=[0])
1466
1467    with session.Session() as sess:
1468      v1, v2, v1_grad, v2_grad = sess.run(
1469          [pfor_out, real_out, sum_pfor_out_grad, real_out_grad])
1470      self.assertAllClose(v1, v2)
1471      self.assertAllClose(v1_grad, v2_grad)
1472
1473
1474def dynamic_lstm_input_fn(batch_size, state_size, max_steps):
1475  # We make inputs and sequence_length constant so that multiple session.run
1476  # calls produce the same result.
1477  inputs = constant_op.constant(
1478      np.random.rand(batch_size, max_steps, state_size), dtype=dtypes.float32)
1479  sequence_length = np.random.randint(0, size=[batch_size], high=max_steps + 1)
1480  sequence_length = constant_op.constant(sequence_length, dtype=dtypes.int32)
1481  return inputs, sequence_length
1482
1483
1484def create_dynamic_lstm(cell_fn, batch_size, state_size, max_steps):
1485  cell = cell_fn(state_size)
1486  inputs, sequence_length = dynamic_lstm_input_fn(batch_size, state_size,
1487                                                  max_steps)
1488  inputs_ta = tensor_array_ops.TensorArray(
1489      dtypes.float32, size=max_steps, element_shape=[batch_size, state_size])
1490  inputs_time_major = array_ops.transpose(inputs, [1, 0, 2])
1491  inputs_ta = inputs_ta.unstack(inputs_time_major)
1492  zeros = array_ops.zeros([state_size])
1493
1494  def loop_fn(i):
1495    sequence_length_i = array_ops.gather(sequence_length, i)
1496
1497    def body_fn(t, state, ta):
1498      inputs_t = array_ops.expand_dims(
1499          array_ops.gather(inputs_ta.read(t), i), 0)
1500      output, new_state = cell(inputs_t, state)
1501      output = array_ops.reshape(output, [-1])
1502      # TODO(agarwal): one optimization that dynamic_rnn uses is to avoid the
1503      # array_ops.where when t < min(sequence_length). Doing that requires
1504      # supporting tf.cond pfor conversion.
1505      done = t >= sequence_length_i
1506      output = array_ops.where(done, zeros, output)
1507      ta = ta.write(t, output)
1508      new_state = [
1509          array_ops.where(done, s, ns)
1510          for s, ns in zip(nest.flatten(state), nest.flatten(new_state))
1511      ]
1512      new_state = nest.pack_sequence_as(state, new_state)
1513      return t + 1, new_state, ta
1514
1515    def condition_fn(t, _, unused):
1516      del unused
1517      return t < max_steps
1518
1519    initial_state = cell.zero_state(1, dtypes.float32)
1520    _, state, ta = control_flow_ops.while_loop(condition_fn, body_fn, [
1521        0, initial_state,
1522        tensor_array_ops.TensorArray(dtypes.float32, max_steps)
1523    ])
1524
1525    new_state = [array_ops.reshape(x, [-1]) for x in nest.flatten(state)]
1526    new_state = nest.pack_sequence_as(initial_state, new_state)
1527    return ta.stack(), new_state
1528
1529  pfor_output = pfor_control_flow_ops.pfor(loop_fn, batch_size)
1530  tf_output = rnn.dynamic_rnn(
1531      cell,
1532      inputs,
1533      sequence_length=sequence_length,
1534      initial_state=cell.zero_state(batch_size, dtypes.float32))
1535  return pfor_output, tf_output
1536
1537
1538@test_util.run_all_in_graph_and_eager_modes
1539class WhileV2Test(PForTestCase):
1540
1541  def setUp(self):
1542    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
1543    control_flow_v2_toggles.enable_control_flow_v2()
1544    super(WhileV2Test, self).setUp()
1545
1546  def tearDown(self):
1547    if not self._enabled:
1548      control_flow_v2_toggles.disable_control_flow_v2()
1549    super(WhileV2Test, self).tearDown()
1550
1551  def test_while_outside_loop(self):
1552
1553    def _f():
1554      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
1555
1556    def loop_fn(i):
1557      return _f() + i
1558
1559    self._test_loop_fn(loop_fn, 3)
1560
1561  def test_invariant_while(self):
1562
1563    def loop_fn(_):
1564      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
1565
1566    self._test_loop_fn(loop_fn, 3)
1567
1568  def test_invariant_while_with_control_dependency(self):
1569
1570    def loop_fn(i):
1571      with ops.control_dependencies([i]):
1572        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
1573                                           [0])
1574
1575    self._test_loop_fn(loop_fn, 3)
1576
1577  def test_while_with_stateful_ops(self):
1578
1579    def loop_fn(_):
1580      j, _ = control_flow_ops.while_loop(
1581          lambda j, x: j < 4, lambda j, x:
1582          (j + 1, x + random_ops.random_uniform([])), [0, 0.])
1583      return j
1584
1585    self._test_loop_fn(loop_fn, 3)
1586
1587  def test_while_with_variable(self):
1588    v = resource_variable_ops.ResourceVariable(5.)
1589
1590    def loop_fn(_):
1591      _, output = control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
1592                                              (j + 1, x + v), [0, 0.])
1593      return output
1594
1595    self._test_loop_fn(loop_fn, 3)
1596
1597  def test_while_unstacked_condition(self):
1598
1599    def loop_fn(i):
1600      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
1601                                         (j + 1, x + i), [0, 0])
1602
1603    self._test_loop_fn(loop_fn, 3)
1604
1605  def test_while(self):
1606    x = random_ops.random_uniform([3, 5])
1607    lengths = constant_op.constant([4, 0, 2])
1608
1609    def loop_fn(i):
1610      x_i = array_ops.gather(x, i)
1611      lengths_i = array_ops.gather(lengths, i)
1612
1613      return control_flow_ops.while_loop(
1614          lambda j, _: j < lengths_i, lambda j, t:
1615          (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
1616
1617    self._test_loop_fn(loop_fn, 3)
1618
1619  def test_while_change_input_invariance(self):
1620    # This tests cases where a loop invariant input to while has loop dependent
1621    # operations applied to it inside the while body.
1622    # It also test inputs that are passed through.
1623    def loop_fn(i):
1624      return control_flow_ops.while_loop(
1625          lambda j, *_: j < i, lambda j, x, y, z, w:
1626          (j + 1, x + i, y + x, z, w), [
1627              0,
1628              constant_op.constant(0),
1629              constant_op.constant(1), i,
1630              constant_op.constant(2)
1631          ])
1632
1633    self._test_loop_fn(loop_fn, 3)
1634
1635  def test_while_shape_invariants(self):
1636
1637    def loop_fn(i):
1638      return control_flow_ops.while_loop(
1639          lambda j, *_: j < 4,
1640          lambda j, x, y: (j + 1, x + i, y + 1),
1641          [0, constant_op.constant([0, 1]),
1642           constant_op.constant([2, 3])],
1643          shape_invariants=[
1644              None,
1645              tensor_shape.TensorShape([2]),
1646              tensor_shape.TensorShape([2])
1647          ])
1648
1649    self._test_loop_fn(loop_fn, 3)
1650
1651  def test_while_jacobian(self):
1652    # Note that we wrap the code below in a tf.function since we don't want the
1653    # while_loop call to be evaluated eagerly using a python loop.
1654    @def_function.function
1655    def _f(x, y, use_pfor):
1656      # out = x @ y @ y @ y @ y, where @ is matmul operator.
1657      _, out = control_flow_ops.while_loop(
1658          lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
1659          [0, x])
1660
1661      def loop_fn(i):
1662        out_i = array_ops.gather(out, i, axis=1)
1663        grad = gradient_ops.gradients(out_i, x)
1664        return array_ops.reshape(grad[0], [-1])
1665
1666      if use_pfor:
1667        return pfor_control_flow_ops.pfor(loop_fn, iters=3)
1668      else:
1669        return pfor_control_flow_ops.for_loop(
1670            loop_fn, iters=3, loop_fn_dtypes=out.dtype)
1671
1672    x = constant_op.constant(np.random.uniform(size=(1, 3)))
1673    y = constant_op.constant(np.random.uniform(size=(3, 3)))
1674    self.assertAllClose(_f(x, y, True), _f(x, y, False))
1675
1676  def test_scan(self):
1677    np.random.seed(seed=42)
1678    data = np.random.randn(3).astype(np.float32)
1679
1680    def log_prob(x):
1681      return math_ops.reduce_sum(functional_ops.scan_v2(
1682          lambda _, yi: (x - yi)**2,
1683          elems=data,
1684          initializer=constant_op.constant(0.)))
1685
1686    x = variables.Variable(array_ops.ones([2]))
1687    self.evaluate(x.initializer)
1688    v_log_prob = lambda x: pfor_control_flow_ops.vectorized_map(log_prob, x)
1689    theoretical, numerical = gradient_checker_v2.compute_gradient(
1690        v_log_prob, (x,), delta=1e-3)
1691    self.assertAllClose(theoretical, numerical, rtol=1e-2)
1692
1693
1694@test_util.run_all_in_graph_and_eager_modes
1695class NestedControlFlowTest(PForTestCase):
1696
1697  def setUp(self):
1698    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
1699    control_flow_v2_toggles.enable_control_flow_v2()
1700    super(NestedControlFlowTest, self).setUp()
1701
1702  def tearDown(self):
1703    if not self._enabled:
1704      control_flow_v2_toggles.disable_control_flow_v2()
1705    super(NestedControlFlowTest, self).tearDown()
1706
1707  def _cond(self, f=None, split=0):
1708    if f is None:
1709      f = lambda x, y: (x, y)
1710
1711    def _f(x, y):
1712      return control_flow_ops.cond(y > split, lambda: f(x, y), lambda:
1713                                   (x + 1., y))
1714
1715    return _f
1716
1717  def _while(self, f=None):
1718    if f is None:
1719      f = lambda x, y: (x, y)
1720
1721    def _f(x, y):
1722      return control_flow_ops.while_loop(
1723          lambda j, _: j < y, lambda j, t:
1724          (j + 1, t + array_ops.gather(f(x, y)[0], j)), [0, x])[1], y
1725
1726    return _f
1727
1728  def _test_helper(self, f):
1729    x = random_ops.random_uniform([5, 5])
1730    y = constant_op.constant([4, -1, 2, -2, 2])
1731
1732    def loop_fn(i):
1733      x_i = array_ops.gather(x, i)
1734      y_i = array_ops.gather(y, i)
1735      return f(x_i, y_i)
1736
1737    self._test_loop_fn(loop_fn, 5)
1738
1739  def test_cond_while(self):
1740    self._test_helper(self._cond(self._while()))
1741
1742  def test_while_cond(self):
1743    self._test_helper(self._while(self._cond()))
1744
1745  def test_while_while(self):
1746    self._test_helper(self._while(self._while()))
1747
1748  def test_cond_cond(self):
1749    self._test_helper(self._cond(self._cond()))
1750
1751
1752@test_util.run_all_in_graph_and_eager_modes
1753@test_util.with_control_flow_v2
1754class StatelessIfTest(PForTestCase):
1755
1756  def test_loop_variant_cond(self):
1757    x = [1, 2, 3, 4, 5.]
1758    y = 2.5
1759
1760    @def_function.function
1761    def loop_fn(i):
1762      x_i = array_ops.gather(x, i)
1763      # Note that the output has a combination of then and else branches being
1764      # loop variant / invariant.
1765      return cond_v2.cond_v2(x_i < y, lambda: (y - x_i, y, 1., 2.), lambda:
1766                             (x_i - y, 0., y, 3.))
1767
1768    self._test_loop_fn(loop_fn, iters=5)
1769
1770  def test_loop_invariant_cond(self):
1771    x = [1, 2, 3, 4, 5.]
1772    y = 0.5
1773    z = random_ops.random_uniform([])
1774
1775    @def_function.function
1776    def loop_fn(i):
1777      x_i = array_ops.gather(x, i)
1778      # Note that the output has a combination of then and else branches being
1779      # loop variant / invariant.
1780      return cond_v2.cond_v2(z < y, lambda: (y - x_i, y, 1., 2.), lambda:
1781                             (x_i - y, 0., y, 3.))
1782
1783    self._test_loop_fn(loop_fn, iters=5)
1784
1785  def test_empty_branch(self):
1786    x = [1, 2, 3, 4, 5.]
1787    y = 6.
1788
1789    @def_function.function
1790    def loop_fn(i):
1791      x_i = array_ops.gather(x, i)
1792      return cond_v2.cond_v2(
1793          x_i < y,  # Note that else branch is empty.
1794          lambda: (y - x_i, y, 1., 2.),
1795          lambda: (x_i - y, 0., y, 3.))
1796
1797    self._test_loop_fn(loop_fn, iters=5)
1798
1799
1800@test_util.run_all_in_graph_and_eager_modes
1801@test_util.with_control_flow_v2
1802class IfTest(PForTestCase):
1803
1804  def test_read_var(self):
1805    self.skipTest("b/156438918")  # Flaky
1806
1807    x = [1, 2, 3, 4, 5.]
1808    y = 2.5
1809    z = resource_variable_ops.ResourceVariable(5.)
1810
1811    @def_function.function
1812    def loop_fn(i):
1813      x_i = array_ops.gather(x, i)
1814      return cond_v2.cond_v2(x_i < y, lambda: z - x_i, lambda: z + x_i)
1815
1816    self._test_loop_fn(loop_fn, iters=5)
1817
1818
1819class RNNTest(PForTestCase):
1820
1821  @test_util.run_v1_only("b/122612051")
1822  def test_dynamic_rnn(self):
1823    pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, 3, 5,
1824                                                   7)
1825    self.run_and_assert_equal(pfor_outputs, tf_outputs)
1826
1827  @test_util.run_v1_only("b/122612051")
1828  def test_dynamic_lstm(self):
1829    pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicLSTMCell, 3, 5,
1830                                                   7)
1831    self.run_and_assert_equal(pfor_outputs, tf_outputs)
1832
1833
1834# TODO(agarwal): benchmark numbers on GPU for graphs based on while_loop
1835# conversion don't look good. Some of it seems like lot of copies between host
1836# and device. Optimize that.
1837class Benchmarks(test.Benchmark):
1838
1839  def _run(self, targets, iters, name=None):
1840
1841    def _done(t):
1842      # Note that we don't use tf.control_dependencies since that will not make
1843      # sure that the computation on GPU has actually finished. So we fetch the
1844      # first element of the output, and assume that this will not be called on
1845      # empty tensors.
1846      return array_ops.gather(array_ops.reshape(t, [-1]), 0)
1847
1848    targets = [_done(x) for x in nest.flatten(targets)]
1849    sess = session.Session()
1850    with sess:
1851      init = variables.global_variables_initializer()
1852      sess.run(init)
1853      run_fn = sess.make_callable(targets)
1854      run_fn()  # Warm up
1855      begin = time.time()
1856      for _ in range(iters):
1857        run_fn()
1858      end = time.time()
1859    avg_time_ms = 1000 * (end - begin) / iters
1860    self.report_benchmark(iters=iters, wall_time=avg_time_ms, name=name)
1861    return avg_time_ms
1862
1863  def benchmark_sess_run_overhead(self):
1864    with ops.Graph().as_default():
1865      x = constant_op.constant(1.0)
1866      self._run(x, 10000, name="session_run_overhead")
1867
1868  def benchmark_add(self):
1869    with ops.Graph().as_default():
1870      n = 256
1871      params = 1000
1872      x = random_ops.random_normal([n, params])
1873      y = random_ops.random_normal([n, params])
1874
1875      def loop_fn(i):
1876        x_i = array_ops.gather(x, i)
1877        y_i = array_ops.gather(y, i)
1878        return x_i + y_i
1879
1880      pfor_outputs = pfor_control_flow_ops.pfor(loop_fn, n)
1881      while_outputs = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, n)
1882      manual = x + y
1883
1884      self._run(manual, 1000, name="manual_add")
1885      self._run(pfor_outputs, 1000, name="pfor_add")
1886      self._run(while_outputs, 100, name="while_add")
1887
1888  def benchmark_matmul(self):
1889    with ops.Graph().as_default():
1890      n = 1024
1891      params = 1000
1892      x = random_ops.random_normal([n, params])
1893      y = random_ops.random_normal([params, params])
1894
1895      def loop_fn(i):
1896        x_i = array_ops.expand_dims(array_ops.gather(x, i), 0)
1897        return math_ops.matmul(x_i, y)
1898
1899      pfor_outputs = pfor_control_flow_ops.pfor(loop_fn, n)
1900      while_outputs = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, n)
1901      manual = math_ops.matmul(x, y)
1902
1903      self._run(manual, 1000, name="manual_matmul")
1904      self._run(pfor_outputs, 1000, name="pfor_matmul")
1905      self._run(while_outputs, 100, name="while_matmul")
1906
1907  def benchmark_map_fn(self):
1908    with ops.Graph().as_default():
1909      b = 256
1910      params = 1000
1911      inp = random_ops.random_normal((b, params))
1912      fn = lambda x: x * x
1913
1914      def pfor_map_fn(f, x):
1915        return pfor_control_flow_ops.pfor(lambda i: f(array_ops.gather(x, i)),
1916                                          array_ops.shape(x)[0])
1917
1918      map_output = map_fn.map_fn(fn, inp)
1919      pfor_output = pfor_map_fn(fn, inp)
1920
1921      self._run(map_output, 100, name="tf_map_fn")
1922      self._run(pfor_output, 100, name="pfor_map_fn")
1923
1924  def benchmark_basic_while(self):
1925    with ops.Graph().as_default():
1926
1927      def loop_fn(i):
1928        _, s = control_flow_ops.while_loop(lambda t, x: t < i, lambda t, x:
1929                                           (t + 1, x + i), [0, 0])
1930        return s
1931
1932      iters = 50
1933      pfor_output = pfor_control_flow_ops.pfor(loop_fn, iters)
1934      for_loop_output = pfor_control_flow_ops.for_loop(loop_fn, dtypes.int32,
1935                                                       iters)
1936      self._run(pfor_output, 100, name="pfor_basic")
1937      self._run(for_loop_output, 100, name="for_loop_basic")
1938
1939  def benchmark_dynamic_rnn(self):
1940    with ops.Graph().as_default():
1941      pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, 128,
1942                                                     512, 16)
1943      self._run(pfor_outputs, 100, name="pfor_rnn")
1944      self._run(tf_outputs, 100, name="tf_rnn")
1945
1946  def benchmark_reduction(self):
1947    n = 1024
1948    with ops.Graph().as_default():
1949      x = random_ops.random_uniform([n, n])
1950      w = random_ops.random_uniform([n, n])
1951
1952      def loop_fn(i, pfor_config):
1953        x_i = array_ops.gather(x, i)
1954        return math_ops.reduce_sum(
1955            math_ops.matmul(pfor_config.reduce_concat(x_i), w))
1956
1957      # Note that output_reduction will be tiled, so there may be some minor
1958      # overheads compared to output_no_reduction.
1959      output_reduction = pfor_control_flow_ops.pfor(loop_fn, n)
1960      output_no_reduction = math_ops.reduce_sum(math_ops.matmul(x, w))
1961      # Benchmark to test that reduction does not add overhead and its output is
1962      # treated as loop invariant.
1963      self._run(output_reduction, 30, name="matmul_reduction")
1964      self._run(output_no_reduction, 30, name="matmul_no_reduction")
1965
1966
1967class SparseTest(PForTestCase):
1968
1969  @test_util.run_v1_only("b/122612051")
1970  def test_var_loop_len(self):
1971    num_iters = array_ops.placeholder(dtypes.int32)
1972
1973    def loop_fn(_):
1974      return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6],
1975                                        [3])  # [0, 2, 0]
1976
1977    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
1978    with self.cached_session() as sess:
1979      sess.run(pfor, feed_dict={num_iters: 3})
1980
1981  @test_util.run_v1_only("b/122612051")
1982  def test_sparse_result_none_stacked(self):
1983    num_iters = 10
1984
1985    def loop_fn(_):
1986      return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6],
1987                                        [3])  # [0, 2, 0]
1988
1989    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
1990
1991    indices = [[i, j] for i in range(num_iters) for j in range(3)]
1992    values = [4, 5, 6] * num_iters
1993    dense_shapes = [num_iters, 3]
1994    # Expected result: [[4, 5, 6], [4, 5, 6], [4, 5, 6], ...]
1995    manual = sparse_tensor.SparseTensor(indices, values, dense_shapes)
1996    self.run_and_assert_equal(pfor, manual)
1997
1998  @test_util.run_v1_only("b/122612051")
1999  def test_sparse_result_all_stacked(self):
2000    num_iters = 10
2001
2002    def loop_fn(i):
2003      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
2004      indices = array_ops.expand_dims(i, 0)
2005      return sparse_tensor.SparseTensor(indices, i, i + 1)  # [0, ..., 0, i]
2006
2007    # Expected result: [[0], [0, 1], [0, 0, 2], [0, 0, 0, 3], ...]
2008    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
2009    manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)],
2010                                        list(range(num_iters)),
2011                                        (num_iters, num_iters))
2012    self.run_and_assert_equal(pfor, manual)
2013
2014  @test_util.run_v1_only("b/122612051")
2015  def test_sparse_result_indices_stacked(self):
2016    num_iters = 10
2017
2018    def loop_fn(i):
2019      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
2020      indices = array_ops.expand_dims(i, 0)
2021      return sparse_tensor.SparseTensor(indices, [1], [num_iters])
2022
2023    # Expected result: identity matrix size num_iters * num_iters
2024    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
2025    manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)],
2026                                        [1] * num_iters, (num_iters, num_iters))
2027    self.run_and_assert_equal(pfor, manual)
2028
2029  @test_util.run_v1_only("b/122612051")
2030  def test_sparse_result_values_stacked(self):
2031    num_iters = 10
2032
2033    def loop_fn(i):
2034      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
2035      return sparse_tensor.SparseTensor([[0]], i, [num_iters])  # [i, 0, ..., 0]
2036
2037    # Expected result: [[1, 0, ...], [2, 0, ...], [3, 0, ...], ...]
2038    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
2039    manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)],
2040                                        list(range(num_iters)),
2041                                        (num_iters, num_iters))
2042    self.run_and_assert_equal(pfor, manual)
2043
2044  @test_util.run_v1_only("b/122612051")
2045  def test_sparse_result_shapes_stacked(self):
2046    num_iters = 10
2047
2048    def loop_fn(i):
2049      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
2050      return sparse_tensor.SparseTensor([[0]], [1], i + 1)  # [1, 0, ..., 0]
2051
2052    # Expected result: [[1, 0, 0, ...], [1, 0, 0, ...], ...]
2053    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
2054    manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)],
2055                                        [1] * num_iters, (num_iters, num_iters))
2056    self.run_and_assert_equal(pfor, manual)
2057
2058  @test_util.run_v1_only("b/122612051")
2059  def test_sparse_result_shapes_stacked_2D(self):
2060    num_iters = 10
2061
2062    def loop_fn(i):
2063      i = array_ops.expand_dims(math_ops.cast(i + 1, dtypes.int64), 0)
2064      shape = array_ops.concat([i, i], 0)
2065      return sparse_tensor.SparseTensor([[0, 0]], [1], shape)  # [1, 0, ..., 0]
2066
2067    # Expected result: [[[1, 0, ...], [0, ..., 0], [0, ..., 0], ...], ...]
2068    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
2069    manual = sparse_tensor.SparseTensor([[i, 0, 0] for i in range(num_iters)],
2070                                        [1] * num_iters,
2071                                        (num_iters, num_iters, num_iters))
2072    self.run_and_assert_equal(pfor, manual)
2073
2074
2075# Dummy CompositeTensor to test CompositeTensor support.
2076class Particle(composite_tensor.CompositeTensor):
2077  """A (batch of) particles each defined by a mass and a scalar velocity."""
2078
2079  def __init__(self, mass, velocity):
2080    mass = ops.convert_to_tensor(mass)
2081    velocity = ops.convert_to_tensor(velocity)
2082    self.shape = array_ops.broadcast_static_shape(mass.shape, velocity.shape)
2083    self.mass = mass
2084    self.velocity = velocity
2085
2086  @property
2087  def _type_spec(self):
2088    return ParticleSpec(
2089        type_spec.type_spec_from_value(self.mass),
2090        type_spec.type_spec_from_value(self.velocity))
2091
2092
2093class ParticleSpec(type_spec.BatchableTypeSpec):
2094
2095  def __init__(self, mass, velocity):
2096    self.shape = array_ops.broadcast_static_shape(
2097        mass.shape, velocity.shape)
2098    self.mass = mass
2099    self.velocity = velocity
2100
2101  def _serialize(self):
2102    return (self.mass, self.velocity)
2103
2104  @property
2105  def value_type(self):
2106    return Particle
2107
2108  @property
2109  def _component_specs(self):
2110    return (self.mass, self.velocity)
2111
2112  def _to_components(self, value):
2113    return (value.mass, value.velocity)
2114
2115  def _from_components(self, components):
2116    return Particle(*components)
2117
2118  def _pad_shape_to_full_rank(self, s):
2119    """Pad component shapes with 1's so all components have the same rank."""
2120    return tensor_shape.TensorShape(
2121        [1] * (self.shape.ndims - s.ndims)).concatenate(s)
2122
2123  def _batch(self, batch_size):
2124    return ParticleSpec(
2125        mass=tensor_spec.TensorSpec(
2126            dtype=self.mass.dtype,
2127            shape=tensor_shape.TensorShape([batch_size]).concatenate(
2128                self._pad_shape_to_full_rank(self.mass.shape))),
2129        velocity=tensor_spec.TensorSpec(
2130            dtype=self.velocity.dtype,
2131            shape=tensor_shape.TensorShape([batch_size]).concatenate(
2132                self._pad_shape_to_full_rank(self.velocity.shape))))
2133
2134  def _unbatch(self):
2135    return ParticleSpec(
2136                tensor_spec.TensorSpec(dtype=self.mass.dtype,
2137                                       shape=self.mass.shape[1:]),
2138                tensor_spec.TensorSpec(dtype=self.velocity.dtype,
2139                                       shape=self.velocity.shape[1:]))
2140
2141  def _to_tensor_list(self, value):
2142    return [array_ops.reshape(
2143                value.mass,
2144                self._pad_shape_to_full_rank(value.mass.shape)),
2145            array_ops.reshape(
2146                value.velocity,
2147                self._pad_shape_to_full_rank(value.velocity.shape))]
2148
2149
2150class CompositeTensorTest(PForTestCase, parameterized.TestCase):
2151
2152  @parameterized.parameters((None,), (3,))
2153  def test_create_composite_inside_loop(self, parallel_iterations):
2154    num_particles = 10
2155    velocities = random_ops.random_uniform([num_particles])
2156    particles = pfor_control_flow_ops.pfor(
2157        # Build a batch of particles all with the same mass.
2158        lambda i: Particle(mass=4., velocity=array_ops.gather(velocities, i)),
2159        num_particles,
2160        parallel_iterations=parallel_iterations)
2161    particles_mass, particles_velocity, velocities = self.evaluate(
2162        (particles.mass, particles.velocity, velocities))
2163    self.assertAllEqual(particles_mass, 4. * np.ones([num_particles]))
2164    self.assertAllEqual(particles_velocity, velocities)
2165
2166  @parameterized.parameters((None,), (3,))
2167  def test_composite_is_converted_to_batched_tensor(
2168      self, parallel_iterations):
2169    particles = pfor_control_flow_ops.pfor(
2170        lambda _: Particle(mass=random_ops.random_uniform([3]),  # pylint: disable=g-long-lambda
2171                           velocity=random_ops.random_uniform([5, 3])),
2172        4,
2173        parallel_iterations=parallel_iterations)
2174    # Naively batching the component shapes would give `[4, 3]` and `[4, 5, 3]`
2175    # which have no consistent broadcast shape.
2176    self.assertTrue(particles.mass.shape, [4, 1, 3])
2177    self.assertAllEqual(particles.velocity.shape, [4, 5, 3])
2178
2179  def test_vectorized_map_gathers_composite_tensors(self):
2180    particles = Particle(mass=[1., 2., 3., 4., 5.],
2181                         velocity=[1., 2., 3., 4., 5.])
2182    self.assertAllEqual(
2183        pfor_control_flow_ops.vectorized_map(
2184            lambda x: x.mass * x.velocity, particles),
2185        particles.mass * particles.velocity)
2186
2187  def test_vectorized_map_of_ragged_tensors(self):
2188    # Vmap should be able to handle ragged Tensors as long as they're not
2189    # *actually* ragged.
2190    ragged = ragged_tensor.RaggedTensor.from_uniform_row_length(
2191        ragged_tensor.RaggedTensor.from_row_lengths(
2192            values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
2193            row_lengths=[3, 3, 3, 3]),
2194        uniform_row_length=2)  # Overall shape [2, 2, 3].
2195    self.assertAllEqual(
2196        pfor_control_flow_ops.vectorized_map(
2197            lambda x: x.to_tensor(shape=[2, 3]), ragged),
2198        ragged.to_tensor(shape=[2, 2, 3]))
2199
2200
2201class ParsingTest(PForTestCase):
2202
2203  def test_decode_csv(self):
2204    csv_tensor = constant_op.constant([["1:2:3"], ["::"], ["7:8:9"]])
2205    kwargs = {"record_defaults": [[10], [20], [30]], "field_delim": ":"}
2206
2207    def loop_fn(i):
2208      line = array_ops.gather(csv_tensor, i)
2209      return parsing_ops.decode_csv(line, **kwargs)
2210
2211    self._test_loop_fn(loop_fn, iters=3)
2212
2213  @test_util.run_v1_only("b/122612051")
2214  def test_parse_single_example(self):
2215
2216    def _int64_feature(*values):
2217      return feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=values))
2218
2219    def _bytes_feature(*values):
2220      return feature_pb2.Feature(
2221          bytes_list=feature_pb2.BytesList(
2222              value=[v.encode("utf-8") for v in values]))
2223
2224    examples = constant_op.constant([
2225        example_pb2.Example(
2226            features=feature_pb2.Features(
2227                feature={
2228                    "dense_int": _int64_feature(i),
2229                    "dense_str": _bytes_feature(str(i)),
2230                    "sparse_int": _int64_feature(i, i * 2, i * 4, i * 8),
2231                    "sparse_str": _bytes_feature(*["abc"] * i)
2232                })).SerializeToString() for i in range(10)
2233    ])
2234
2235    features = {
2236        "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0),
2237        "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""),
2238        "sparse_int": parsing_ops.VarLenFeature(dtypes.int64),
2239        "sparse_str": parsing_ops.VarLenFeature(dtypes.string),
2240    }
2241
2242    def loop_fn(i):
2243      example_proto = array_ops.gather(examples, i)
2244      f = parsing_ops.parse_single_example(example_proto, features)
2245      return f
2246
2247    pfor = pfor_control_flow_ops.pfor(loop_fn, iters=10)
2248    manual = parsing_ops.parse_example(examples, features)
2249    self.run_and_assert_equal(pfor, manual)
2250
2251
2252class PartitionedCallTest(PForTestCase):
2253
2254  def test_simple(self):
2255
2256    @def_function.function
2257    def f(x):
2258      return math_ops.square(x) + 1
2259
2260    z = random_ops.random_uniform([4])
2261
2262    def loop_fn(i):
2263      return f(array_ops.gather(z, i))
2264
2265    self._test_loop_fn(loop_fn, 4)
2266
2267  def test_nested_calls(self):
2268
2269    @def_function.function
2270    def inner(x):
2271      return math_ops.square(x)
2272
2273    @def_function.function
2274    def outer(y):
2275      return math_ops.reduce_sum(inner(y)) + 2
2276
2277    z = random_ops.random_uniform([4, 2])
2278
2279    def loop_fn(i):
2280      return outer(array_ops.gather(z, i))
2281
2282    self._test_loop_fn(loop_fn, 4)
2283
2284  def test_nested_definition(self):
2285
2286    @def_function.function
2287    def outer(y):
2288
2289      @def_function.function
2290      def inner(x):
2291        return math_ops.square(x) + 1
2292
2293      return math_ops.reduce_sum(inner(y)) + 2
2294
2295    z = random_ops.random_uniform([4, 2])
2296
2297    def loop_fn(i):
2298      return outer(array_ops.gather(z, i))
2299
2300    self._test_loop_fn(loop_fn, 4)
2301
2302  def test_gradients(self):
2303
2304    @def_function.function
2305    def f(x):
2306      return math_ops.square(x) + 1
2307
2308    z = random_ops.random_uniform([4, 2])
2309
2310    def loop_fn(i):
2311      z_i = array_ops.gather(z, i)
2312      with backprop.GradientTape() as g:
2313        g.watch(z_i)
2314        out = f(z_i)
2315      return out, g.gradient(out, z_i)
2316
2317    self._test_loop_fn(loop_fn, 4)
2318
2319  def test_stateful_with_gradients(self):
2320
2321    z = random_ops.random_uniform([4, 2])
2322    v = variables.Variable(z[0])
2323
2324    @def_function.function
2325    def f(x):
2326      return math_ops.square(x) + v + 1
2327
2328    def loop_fn(i):
2329      z_i = array_ops.gather(z, i)
2330      with backprop.GradientTape() as g:
2331        g.watch(z_i)
2332        out = f(z_i)
2333      return out, g.gradient(out, z_i)
2334
2335    self._test_loop_fn(loop_fn, 4)
2336
2337
2338class SpectralTest(PForTestCase, parameterized.TestCase):
2339
2340  @parameterized.parameters(
2341      (fft_ops.fft,),
2342      (fft_ops.fft2d,),
2343      (fft_ops.fft3d,),
2344      (fft_ops.ifft,),
2345      (fft_ops.ifft2d,),
2346      (fft_ops.ifft3d,),
2347  )
2348  def test_fft(self, op_func):
2349    shape = [2, 3, 4, 3, 4]
2350    x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
2351
2352    def loop_fn(i):
2353      x_i = array_ops.gather(x, i)
2354      return op_func(x_i)
2355
2356    self._test_loop_fn(loop_fn, 2)
2357
2358  @parameterized.parameters(
2359      (fft_ops.rfft,),
2360      (fft_ops.rfft2d,),
2361      (fft_ops.rfft3d,),
2362  )
2363  def test_rfft(self, op_func):
2364    for dtype in (dtypes.float32, dtypes.float64):
2365      x = random_ops.random_uniform([2, 3, 4, 3, 4], dtype=dtype)
2366
2367      # pylint: disable=cell-var-from-loop
2368      def loop_fn(i):
2369        x_i = array_ops.gather(x, i)
2370        return op_func(x_i)
2371
2372      # pylint: enable=cell-var-from-loop
2373
2374      self._test_loop_fn(loop_fn, 2)
2375
2376  @parameterized.parameters(
2377      (fft_ops.irfft,),
2378      (fft_ops.irfft2d,),
2379      (fft_ops.irfft3d,),
2380  )
2381  def test_irfft(self, op_func):
2382    if config.list_physical_devices("GPU"):
2383      # TODO(b/149957923): The test is flaky
2384      self.skipTest("b/149957923: irfft vectorization flaky")
2385    for dtype in (dtypes.complex64, dtypes.complex128):
2386      shape = [2, 3, 4, 3, 4]
2387      x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
2388      x = math_ops.cast(x, dtype=dtype)
2389
2390      # pylint: disable=cell-var-from-loop
2391      def loop_fn(i):
2392        x_i = array_ops.gather(x, i)
2393        return op_func(x_i)
2394
2395      # pylint: enable=cell-var-from-loop
2396
2397      self._test_loop_fn(loop_fn, 2)
2398
2399
2400class VariableTest(PForTestCase):
2401
2402  def test_create_variable_once(self):
2403    x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
2404    y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
2405    a_var = []
2406
2407    def f(z):
2408      if not a_var:
2409        a_var.append(variables.Variable(lambda: y, name="a"))
2410      return math_ops.matmul(z, a_var[0] / 16)
2411
2412    pfor_control_flow_ops.vectorized_map(f, x)
2413
2414  @test_util.run_v2_only
2415  def test_create_variable_repeated(self):
2416    x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
2417    y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
2418
2419    def f(z):
2420      a_var = variables.Variable(lambda: y, name="a") / 4
2421      return math_ops.matmul(z, a_var / 16)
2422
2423    # Note that this error is only raised under v2 behavior.
2424    with self.assertRaisesRegex(
2425        ValueError,
2426        "tf.function-decorated function tried to create variables on non-first"
2427    ):
2428      pfor_control_flow_ops.vectorized_map(f, x)
2429
2430  @test_util.run_all_in_graph_and_eager_modes
2431  def test_variable_shape(self):
2432    v = resource_variable_ops.ResourceVariable([1, 2])
2433
2434    def loop_fn(_):
2435      return resource_variable_ops.variable_shape(v.handle)
2436
2437    self._test_loop_fn(loop_fn, 2)
2438
2439
2440if __name__ == "__main__":
2441  test.main()
2442