# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for pfor and for_loop."""
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import time

from absl.testing import parameterized
import numpy as np

from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients as gradient_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest


@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class PForTest(PForTestCase):

  def test_op_conversion_fallback_to_while_loop(self):
    # Note that we used top_k op for this test. If a converter gets defined for
    # it, we will need to find another op for which a converter has not been
    # defined.
    x = random_ops.random_uniform([3, 2, 4])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return nn.top_k(x_i)

    with self.assertRaisesRegex(ValueError, "No pfor vectorization"):
      self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
    self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)

  def test_parallel_iterations(self):
    for parallel_iterations in [2, 3, 8, 10]:
      x = random_ops.random_uniform([8, 3])

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        return array_ops.gather(x, i)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 8, parallel_iterations=parallel_iterations)
      self._test_loop_fn(
          loop_fn,
          4 * constant_op.constant(2),
          parallel_iterations=parallel_iterations)

  def test_parallel_iterations_preserves_static_shape(self):
    for parallel_iterations in [2, 3, 8, 10]:
      x = pfor_control_flow_ops.pfor(
          lambda _: random_ops.random_uniform([2, 3]),
          8,
          parallel_iterations=parallel_iterations)
      self.assertAllEqual(x.shape, [8, 2, 3])

  def test_parallel_iterations_zero(self):
    with self.assertRaisesRegex(ValueError, "positive integer"):
      pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=0)
    with self.assertRaisesRegex(TypeError, "positive integer"):
      pfor_control_flow_ops.for_loop(
          lambda i: 1, dtypes.int32, 8, parallel_iterations=0)

  def test_parallel_iterations_one(self):
    with self.assertRaisesRegex(ValueError, "Use for_loop instead"):
      pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=1)

  def test_vectorized_map(self):

    def compute(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    result = pfor_control_flow_ops.vectorized_map(compute,
                                                  array_ops.ones((10, 5, 3)))
    self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))

  def test_vectorized_map_with_dynamic_shape(self):

    def compute(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    x = array_ops.placeholder_with_default(
        array_ops.ones((10, 5, 3)), shape=None)
    result = pfor_control_flow_ops.vectorized_map(compute, x)
    self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))

  def test_vectorized_map_broadcasts_unit_dimensions(self):
    convert_with_static_shape = ops.convert_to_tensor
    convert_with_dynamic_shape = (
        lambda x: array_ops.placeholder_with_default(x, shape=None))

    for convert in (convert_with_static_shape, convert_with_dynamic_shape):
      a = convert([3.1])
      b = convert([-2., 6., 9.])

      # One elem with leading unit dimension.
      a_plus_1 = pfor_control_flow_ops.vectorized_map(lambda a: a + 1, a)
      self.assertAllEqual(*self.evaluate((a_plus_1, a + 1)))

      # Two elems, both with leading unit dimension.
      a_plus_a = pfor_control_flow_ops.vectorized_map(sum, (a, a))
      self.assertAllEqual(*self.evaluate((a_plus_a, a + a)))

      # Elem with unit dimension broadcast against elem with batch dim.
      a_plus_b = pfor_control_flow_ops.vectorized_map(sum, (a, b))
      self.assertAllEqual(*self.evaluate((a_plus_b, a + b)))

  def test_vectorized_map_example_1(self):

    def outer_product(a):
      return math_ops.tensordot(a, a, 0)

    batch_size = 100
    a = array_ops.ones((batch_size, 32, 32))
    c = pfor_control_flow_ops.vectorized_map(outer_product, a)
    self.assertAllEqual((batch_size, 32, 32, 32, 32), c.shape)

  def test_disable_tf_function(self):
    def_function.run_functions_eagerly(True)
    # vectorized_map should ignore disabling tf.functions.
    self.assertTrue(def_function.functions_run_eagerly())
    self.assertAllEqual([0, 1, 4, 9],
                        pfor_control_flow_ops.vectorized_map(
                            lambda x: x * x, math_ops.range(4)))
    self.assertTrue(def_function.functions_run_eagerly())
    def_function.run_functions_eagerly(False)
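

# Illustrative sketch (added commentary, not part of the original suite; the
# helper name is hypothetical): vectorized_map is roughly sugar over pfor,
# where each iteration applies `fn` to one unstacked element of `elems`. The
# same pattern appears in Benchmarks.benchmark_map_fn near the end of this
# file.
def _vectorized_map_sketch(fn, elems):
  return pfor_control_flow_ops.pfor(
      lambda i: fn(array_ops.gather(elems, i)),
      array_ops.shape(elems)[0])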


@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesTest(PForTestCase):

  def test_indexed_slices(self):

    def loop_fn(i):
      return indexed_slices.IndexedSlices(
          indices=i, values=array_ops.reshape(i, [1]), dense_shape=[3, 1])

    self._test_loop_fn(loop_fn, 2)

  def test_indexed_slices_components(self):

    def loop_fn(i):
      slices = indexed_slices.IndexedSlices(
          indices=i, values=array_ops.reshape(i, [1]), dense_shape=[3, 1])
      # Note that returning the components inside the slice avoids
      # densification, which may be more efficient.
      return slices.values, slices.indices

    self._test_loop_fn(loop_fn, 2)


@test_util.run_all_in_graph_and_eager_modes
class ReductionTest(PForTestCase):

  def test_reduce(self):

    def reduce_fn(p, q):
      return math_ops.reduce_mean(p + q, axis=0)

    x = random_ops.random_uniform([4, 3, 2])
    y = random_ops.random_uniform([4, 3, 2])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      y_i = array_ops.gather(y, i)
      reduced = pfor_config.reduce(reduce_fn, x_i, y_i)
      return reduced + x_i

    output = pfor_control_flow_ops.pfor(loop_fn, 4)
    ans = reduce_fn(x, y) + x
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_concat(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      vectorized_value = pfor_config.reduce_concat(x_i)
      mean_value = math_ops.reduce_mean(vectorized_value, axis=0)
      return x_i - mean_value

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_mean(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_mean(x_i)

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_sum(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_sum(x_i)

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_sum(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_class(self):
    x = random_ops.random_uniform([8, 3])

    class LoopFn(object):

      def __init__(self):
        pass

      def __call__(self, i, pfor_config):
        x_i = array_ops.gather(x, i)
        return x_i - pfor_config.reduce_mean(x_i)

    output = pfor_control_flow_ops.pfor(LoopFn(), 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_functools_partial(self):
    x = random_ops.random_uniform([8, 3])

    def fn(i, pfor_config, dummy=None):
      del dummy
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_mean(x_i)

    loop_fn = functools.partial(fn, dummy=1)
    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_parallel_iterations(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return pfor_config.reduce_sum(x_i)

    with self.assertRaisesRegex(ValueError,
                                "parallel_iterations currently unsupported"):
      pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2)

  def test_var_loop_len(self):
    if context.executing_eagerly():
      self.skipTest("Variable length not possible under eager execution.")

    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      return pfor_config.reduce_sum(array_ops.gather(x, i))

    num_iters = array_ops.placeholder(dtypes.int32)
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    with self.cached_session() as sess:
      sess.run(pfor, feed_dict={num_iters: 8})
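

# A minimal illustrative sketch (added commentary; the helper name is
# hypothetical, not part of the original suite): pfor_config.reduce_mean
# computes its reduction across all pfor iterations, so the loop below
# mean-centers each row of `x` against the whole batch, matching
# x - reduce_mean(x, axis=0) as the tests above assert.
def _mean_center_with_pfor(x, n):

  def loop_fn(i, pfor_config):
    x_i = array_ops.gather(x, i)
    return x_i - pfor_config.reduce_mean(x_i)

  return pfor_control_flow_ops.pfor(loop_fn, n)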


@test_util.run_all_in_graph_and_eager_modes
class BitwiseTest(PForTestCase):

  def test_unary_cwise(self):
    for op in [bitwise_ops.invert]:
      x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x1 = array_ops.gather(x, i)
        return op(x1)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 3)

  def test_binary_cwise(self):
    binary_ops = [
        bitwise_ops.bitwise_and,
        bitwise_ops.bitwise_or,
        bitwise_ops.bitwise_xor,
        bitwise_ops.left_shift,
        bitwise_ops.right_shift,
    ]
    for op in binary_ops:
      x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)
      y = random_ops.random_uniform([3, 5], maxval=10, dtype=dtypes.int32)

      output_dtypes = []

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x1 = array_ops.gather(x, i)
        y1 = array_ops.gather(y, i)
        outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
        del output_dtypes[:]
        output_dtypes.extend(t.dtype for t in outputs)
        return outputs

      # pylint: enable=cell-var-from-loop
      self._test_loop_fn(loop_fn, 3)


@test_util.run_all_in_graph_and_eager_modes
class ImageTest(PForTestCase):

  def test_adjust_contrast(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_contrast(image, 2.0)

    self._test_loop_fn(loop_fn, 3)

  def test_adjust_hue(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_hue(image, .25)

    self._test_loop_fn(loop_fn, 3)

  def test_adjust_saturation(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_saturation(image, 0.1)

    self._test_loop_fn(loop_fn, 3)


@test_util.run_all_in_graph_and_eager_modes
class NNTest(PForTestCase):

  def test_conv2d(self):
    x = random_ops.random_uniform([3, 2, 12, 12, 3])
    filt = random_ops.random_uniform([3, 3, 3, 7])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      return nn.conv2d(
          x1, filt, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_conv2d_backprop_input(self):
    x_shape = [2, 12, 12, 3]
    filt = random_ops.random_uniform([3, 3, 3, 7])
    grad = random_ops.random_uniform([3, 2, 5, 5, 7])

    def loop_fn(i):
      grad1 = array_ops.gather(grad, i)
      return nn.conv2d_backprop_input(
          x_shape,
          filt,
          grad1,
          strides=[1, 2, 2, 1],
          padding="VALID",
          data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_conv2d_backprop_filter(self):
    x = random_ops.random_uniform([3, 2, 12, 12, 3])
    x_0 = array_ops.gather(x, 0)
    filter_sizes = [3, 3, 3, 7]
    grad = random_ops.random_uniform([3, 2, 5, 5, 7])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      grad_i = array_ops.gather(grad, i)
      return [
          nn.conv2d_backprop_filter(
              inp,
              filter_sizes,
              grad_i,
              strides=[1, 2, 2, 1],
              padding="VALID",
              data_format="NHWC") for inp in [x_i, x_0]
      ]

    self._test_loop_fn(loop_fn, 3)

  def test_avg_pool(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 3, 3, 1]

      def loop_fn(i):
        with g:
          x1 = array_ops.gather(x, i)
          output = nn.avg_pool(
              x1,
              ksize,
              strides=[1, 2, 2, 1],
              padding="VALID",
              data_format="NHWC")
          loss = nn.l2_loss(output)
        return output, g.gradient(loss, x1)

    self._test_loop_fn(loop_fn, 3)

  def test_avg_pool3d(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([5, 3, 7, 6, 6, 5])
      g.watch(x)
      ksize = [1, 2, 2, 2, 1]
      strides = [1, 2, 2, 2, 1]

      def loop_fn(i):
        with g:
          x1 = array_ops.gather(x, i)
          output = nn.avg_pool3d(
              x1, ksize, strides=strides, padding="VALID", data_format="NDHWC")
          loss = nn.l2_loss(output)
        return output, g.gradient(loss, x1)

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 3, 3, 1]
      strides = [1, 2, 2, 1]

      def loop_fn(i):
        with g:
          x1 = array_ops.gather(x, i)
          output = nn.max_pool(
              x1, ksize, strides=strides, padding="VALID", data_format="NHWC")
          loss = nn.l2_loss(output)
          ones = array_ops.ones_like(output)
          g.watch(ones)
          grad = g.gradient(loss, x1, output_gradients=ones)
        grad_grad = g.gradient(grad, ones)
        return output, grad, grad_grad

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool_v2(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 3, 3, 1]
      strides = [1, 2, 2, 1]

      def loop_fn(i):
        with g:
          x1 = array_ops.gather(x, i)
          output = gen_nn_ops.max_pool_v2(
              x1, ksize, strides=strides, padding="VALID", data_format="NHWC")
          loss = nn.l2_loss(output)
          ones = array_ops.ones_like(output)
          g.watch(ones)
          grad = g.gradient(loss, x1, output_gradients=ones)
        grad_grad = g.gradient(grad, ones)
        return output, grad, grad_grad

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool3d(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 1, 3, 3, 1]
      strides = [1, 1, 2, 2, 1]

      def loop_fn(i):
        with g:
          x1 = array_ops.gather(x, i)
          output = nn.max_pool3d(
              x1, ksize, strides=strides, padding="VALID", data_format="NDHWC")
          loss = nn.l2_loss(output)
          ones = array_ops.ones_like(output)
          g.watch(ones)
          grad = g.gradient(loss, x1, output_gradients=ones)
        grad_grad = g.gradient(grad, ones)
        return output, grad, grad_grad

    self._test_loop_fn(loop_fn, 3)

  def test_fused_batch_norm(self):
    data_formats = ["NHWC"]
    if test.is_gpu_available():
      data_formats.append("NCHW")
    for is_training in (True, False):
      for data_format in data_formats:
        with backprop.GradientTape(persistent=True) as g:
          if data_format == "NCHW":
            x = random_ops.random_uniform([3, 1, 2, 5, 5])
          else:
            x = random_ops.random_uniform([3, 1, 5, 5, 2])
          g.watch(x)
          scale = random_ops.random_uniform([2])
          g.watch(scale)
          offset = random_ops.random_uniform([2])
          g.watch(offset)
          mean = None if is_training else random_ops.random_uniform([2])
          variance = None if is_training else random_ops.random_uniform([2])

        # pylint: disable=cell-var-from-loop
        def loop_fn(i):
          with g:
            x1 = array_ops.gather(x, i)
            outputs = nn.fused_batch_norm(
                x1,
                scale,
                offset,
                mean=mean,
                variance=variance,
                epsilon=0.01,
                data_format=data_format,
                is_training=is_training)
            outputs = list(outputs)
            # We only test the first value of outputs when is_training is
            # False. It looks like CPU and GPU have different outputs for
            # batch_mean and batch_variance for this case.
            if not is_training:
              outputs[1] = constant_op.constant(0.)
              outputs[2] = constant_op.constant(0.)
            loss = nn.l2_loss(outputs[0])
          if is_training:
            gradients = g.gradient(loss, [x1, scale, offset])
          else:
            gradients = [constant_op.constant(0.)] * 3
          return outputs + gradients

        # pylint: enable=cell-var-from-loop

        self._test_loop_fn(loop_fn, 3)

  def test_log_softmax(self):
    logits = random_ops.random_uniform([3, 2, 4])

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      return (nn.log_softmax(logits_i), nn.log_softmax(logits_i, axis=0),
              nn.log_softmax(logits_i, axis=-1))

    self._test_loop_fn(loop_fn, 3)

  def test_softmax(self):
    logits = random_ops.random_uniform([3, 2, 4])

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      return (nn.softmax(logits_i), nn.softmax(logits_i, axis=0),
              nn.softmax(logits_i, axis=-1))

    self._test_loop_fn(loop_fn, 3)

  def test_softmax_cross_entropy_with_logits(self):
    with backprop.GradientTape(persistent=True) as g:
      logits = random_ops.random_uniform([3, 2, 4])
      g.watch(logits)
      labels = random_ops.random_uniform([3, 2, 4])
      labels /= math_ops.reduce_sum(labels, axis=[2], keepdims=True)

      def loop_fn(i):
        with g:
          logits_i = array_ops.gather(logits, i)
          labels_i = array_ops.gather(labels, i)
          loss = nn.softmax_cross_entropy_with_logits(
              labels=labels_i, logits=logits_i)
          total_loss = math_ops.reduce_sum(loss)
        return loss, g.gradient(total_loss, logits_i)

    self._test_loop_fn(loop_fn, 3)

  def test_sparse_softmax_cross_entropy_with_logits(self):
    logits = random_ops.random_uniform([3, 2, 4])
    labels = random_ops.random_uniform(
        shape=[3, 2], maxval=4, dtype=dtypes.int32)

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      labels_i = array_ops.gather(labels, i)
      loss = nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels_i, logits=logits_i)
      return loss

    self._test_loop_fn(loop_fn, 3)
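

# Added commentary (an assumption about converter behavior, not asserted by
# the original file; the helper name is hypothetical): pfor typically
# vectorizes a stateful random op by drawing one batched sample, e.g. the
# sketch below produces a single [5, 3] draw rather than five independent [3]
# draws. This is why the random tests below compare only output shapes, not
# values.
def _batched_random_sketch():
  return pfor_control_flow_ops.pfor(
      lambda _: random_ops.random_uniform([3]), 5)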


class RandomTest(PForTestCase):

  # The random values generated in the two implementations are not guaranteed
  # to match. So we only check the returned shapes.
  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
    outputs = self._run_targets(targets1, targets2)
    n = len(outputs) // 2
    for i in range(n):
      self.assertAllEqual(outputs[i].shape, outputs[i + n].shape)

  def test_random_uniform(self):

    def loop_fn(_):
      return random_ops.random_uniform([3])

    self._test_loop_fn(loop_fn, 5)

  def test_random_uniform_int(self):

    def loop_fn(_):
      return random_ops.random_uniform([3], maxval=1, dtype=dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  def test_random_standard_normal(self):

    def loop_fn(_):
      return random_ops.random_normal([3])

    self._test_loop_fn(loop_fn, 5)

  def test_truncated_normal(self):

    def loop_fn(_):
      return random_ops.truncated_normal([3])

    self._test_loop_fn(loop_fn, 5)

  def test_random_gamma_invariant_alpha(self):

    def loop_fn(_):
      return random_ops.random_gamma([3], alpha=[0.5])

    self._test_loop_fn(loop_fn, 5)

  def test_random_gamma_varying_alpha(self):
    alphas = math_ops.exp(random_ops.random_normal([5, 3, 2]))

    def loop_fn(i):
      alphas_i = array_ops.gather(alphas, i)
      # Test both scalar and non-scalar params and shapes.
      return (random_ops.random_gamma(alpha=alphas_i[0, 0], shape=[]),
              random_ops.random_gamma(alpha=alphas_i, shape=[]),
              random_ops.random_gamma(alpha=alphas_i[0, 0], shape=[3]),
              random_ops.random_gamma(alpha=alphas_i, shape=[3]))

    self._test_loop_fn(loop_fn, 5)

  def test_random_poisson_v2_invariant_rate(self):

    def loop_fn(_):
      return random_ops.random_poisson(lam=[1.3], shape=[3])

    self._test_loop_fn(loop_fn, 5)

  def test_random_poisson_v2_varying_rate(self):
    rates = math_ops.exp(random_ops.random_normal([5, 3, 2]))

    def loop_fn(i):
      rates_i = array_ops.gather(rates, i)
      # Test both scalar and non-scalar params and shapes.
      return (random_ops.random_poisson(lam=rates_i[0, 0], shape=[]),
              random_ops.random_poisson(lam=rates_i, shape=[]),
              random_ops.random_poisson(lam=rates_i[0, 0], shape=[3]),
              random_ops.random_poisson(lam=rates_i, shape=[3]))

    self._test_loop_fn(loop_fn, 5)

  def test_random_multinomial_invariant_logits(self):

    def loop_fn(_):
      return random_ops.categorical(logits=[[1., -1.]], num_samples=3)

    self._test_loop_fn(loop_fn, 5)

  def test_random_multinomial_varying_logits(self):
    logits = random_ops.random_normal([5, 3, 2])

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      return random_ops.categorical(logits_i, num_samples=3)

    self._test_loop_fn(loop_fn, 5)


class StatelessRandomTest(PForTestCase):

  # This test currently only checks that the vectorized and non-vectorized
  # outputs have the same shapes. This is needed since under XLA compilation,
  # stateless random ops can generate different random numbers.
  # TODO(agarwal): switch to checking for actual values matching once
  # b/149402339 is resolved.
  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
    outputs = self._run_targets(targets1, targets2)
    n = len(outputs) // 2
    for i in range(n):
      self.assertAllEqual(outputs[i].shape, outputs[i + n].shape)

  # TODO(agarwal): add tests for other random functions.
  def test_multinomial(self):
    seeds = [[1, 2], [3, 4]]
    logits = random_ops.random_uniform([2, 3, 4])

    def loop_fn(i):
      logits_0 = array_ops.gather(logits, 0)
      logits_i = array_ops.gather(logits, i)
      seeds_0 = array_ops.gather(seeds, 0)
      seeds_i = array_ops.gather(seeds, i)
      return (stateless_random_ops.stateless_categorical(
          logits=logits_i, num_samples=3, seed=seeds_i),
              stateless_random_ops.stateless_categorical(
                  logits=logits_i, num_samples=3, seed=seeds_0),
              stateless_random_ops.stateless_categorical(
                  logits=logits_0, num_samples=3, seed=seeds_i),
              stateless_random_ops.stateless_categorical(
                  logits=logits_0, num_samples=3, seed=seeds_0))

    self._test_loop_fn(loop_fn, 2)


class LoggingTest(PForTestCase):

  @test_util.run_v1_only("b/122612051")
  def test_print(self):
    x = random_ops.random_uniform([3, 5])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      return logging_ops.Print(
          x1, [x1, "x1", array_ops.shape(x1)], summarize=10)

    self._test_loop_fn(loop_fn, 3)

  def test_assert(self):

    def loop_fn(i):
      return control_flow_ops.Assert(i < 10, [i, [10], [i + 1]])

    # TODO(agarwal): make this work with for_loop.
    with session.Session() as sess:
      sess.run(pfor_control_flow_ops.pfor(loop_fn, 3))
      sess.run(pfor_control_flow_ops.pfor(
          lambda i, pfor_config: loop_fn(i), 3))


class TensorArrayTest(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.disable_control_flow_v2()
    super(TensorArrayTest, self).setUp()

  def tearDown(self):
    if self._enabled:
      control_flow_v2_toggles.enable_control_flow_v2()
    super(TensorArrayTest, self).tearDown()

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_read(self):

    ta = tensor_array_ops.TensorArray(
        dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1)

    def loop_fn(i):
      return ta.read(i), ta.read(0)

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_gather(self):

    ta = tensor_array_ops.TensorArray(
        dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1)

    def loop_fn(i):
      return ta.gather([i]), ta.gather([0, 1])

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_write_and_scatter(self):

    t = tensor_array_ops.TensorArray(dtypes.int32, 10, clear_after_read=False)
    handle = t.handle

    def loop_fn(i):
      ta = t.write(i + 2, 2 * i).write(i, 5)
      ta = ta.scatter([4 + i], [4]).scatter([6 + i, 8 + i], [6 + i, 8 + i])
      return ta.flow

    t1 = pfor_control_flow_ops.pfor(loop_fn, iters=2)
    out1 = tensor_array_ops.TensorArray(
        dtypes.int32, handle=handle, flow=t1[-1]).stack()
    output1 = self._run_targets(out1)

    t2 = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, iters=2)
    out2 = tensor_array_ops.TensorArray(
        dtypes.int32, handle=handle, flow=t2[-1]).stack()
    output2 = self._run_targets(out2)
    self.assertAllClose(output2, output1)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_write(self):

    def loop_fn(i):
      # TODO(agarwal): switching the order of writes to ta1 does not work.
      ta1 = tensor_array_ops.TensorArray(dtypes.int32,
                                         2).write(0, i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(dtypes.int32, 1).write(0, 1)
      return ta1.stack(), ta2.stack()

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_scatter(self):

    def loop_fn(i):
      # TODO(agarwal): switching the order of scatter to ta1 does not work.
      ta1 = tensor_array_ops.TensorArray(dtypes.int32,
                                         2).scatter([0],
                                                    [[i, 2]]).scatter([1],
                                                                      [[1, 2]])
      ta2 = tensor_array_ops.TensorArray(dtypes.int32,
                                         2).scatter([0], [3]).scatter([1], [4])
      return ta1.stack(), ta2.stack()

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_read(self):

    def loop_fn(i):
      ta1 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2)
      # TODO(agarwal): ta1.read(i) currently is not supported.
      return ta1.read(0), ta2.read(0), ta2.read(i)

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_gather(self):

    def loop_fn(i):
      ta1 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2)
      # TODO(agarwal): ta1.read(i) currently is not supported.
      return ta1.gather([0, 1]), ta2.gather([0, 1]), ta2.gather([i])

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_grad(self):
    x = random_ops.random_uniform([3, 2])
    ta = tensor_array_ops.TensorArray(
        dtypes.float32, 3, clear_after_read=False).unstack(x)
    y = math_ops.square(ta.stack())

    def loop_fn(i):
      y_i = array_ops.gather(y, i)
      grad = gradient_ops.gradients(y_i, x)[0]
      return array_ops.gather(grad, i)

    t1 = pfor_control_flow_ops.pfor(loop_fn, iters=3)
    # y = x * x. Hence dy/dx = 2 * x.
    actual_grad = 2.0 * x
    with session.Session() as sess:
      actual_grad, computed_grad = sess.run([t1, actual_grad])
      self.assertAllClose(actual_grad, computed_grad)
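

# Minimal illustrative sketch (added commentary; the helper name is
# hypothetical and the expected values are inferred from the tests below): a
# loop-variant TensorList write inside pfor behaves like a batched write, so
# stacking the per-iteration lists from this loop yields
# [[0, 1], [1, 1], [2, 1]].
def _tensor_list_pfor_sketch():

  def loop_fn(i):
    h = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    h = list_ops.tensor_list_set_item(h, 0, i)
    h = list_ops.tensor_list_set_item(h, 1, 1)
    return list_ops.tensor_list_stack(h, dtypes.int32)

  return pfor_control_flow_ops.pfor(loop_fn, 3)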


@test_util.run_all_in_graph_and_eager_modes
class TensorListTest(PForTestCase):

  def test_create_outside_and_write(self):
    handle1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    handle2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)

    def loop_fn(i):
      h1 = list_ops.tensor_list_set_item(handle1, 0, i)
      h1 = list_ops.tensor_list_set_item(h1, 1, 1)
      h2 = list_ops.tensor_list_set_item(handle2, 0, 1)
      return (list_ops.tensor_list_stack(h1, dtypes.int32),
              list_ops.tensor_list_stack(h2, dtypes.int32))

    self._test_loop_fn(loop_fn, 3)

  def test_create_inside_and_write(self):

    def loop_fn(i):
      h1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      h1 = list_ops.tensor_list_set_item(h1, 0, i)
      h1 = list_ops.tensor_list_set_item(h1, 1, 1)
      h2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      h2 = list_ops.tensor_list_set_item(h2, 0, 1)
      return (list_ops.tensor_list_stack(h1, dtypes.int32),
              list_ops.tensor_list_stack(h2, dtypes.int32))

    self._test_loop_fn(loop_fn, 3)

  def test_create_outside_and_read(self):
    handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    handle = list_ops.tensor_list_set_item(handle, 0, 0)
    handle = list_ops.tensor_list_set_item(handle, 1, 1)

    def loop_fn(i):
      return (list_ops.tensor_list_get_item(handle, i, dtypes.int32),
              list_ops.tensor_list_get_item(handle, 0, dtypes.int32),
              list_ops.tensor_list_length(handle),
              list_ops.tensor_list_element_shape(handle, dtypes.int32),
              list_ops.tensor_list_element_shape(handle, dtypes.int64))

    self._test_loop_fn(loop_fn, 2)

  @test_util.disable_tfrt("b/180206304")
  def test_create_inside_and_read(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      handle = list_ops.tensor_list_set_item(handle, 0, i)
      handle = list_ops.tensor_list_set_item(handle, 1, 1)
      return (list_ops.tensor_list_get_item(handle, 0, dtypes.int32),
              list_ops.tensor_list_get_item(handle, i, dtypes.int32),
              list_ops.tensor_list_length(handle),
              list_ops.tensor_list_element_shape(handle, dtypes.int32),
              list_ops.tensor_list_element_shape(handle, dtypes.int64))

    self._test_loop_fn(loop_fn, 2)

  def test_create_outside_and_push_back(self):
    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)

    def loop_fn(i):
      handle = list_ops.tensor_list_push_back(h, [i, 2])
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_create_inside_and_push_back(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_push_back(handle, [i, 2])
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_pop_back_no_shape(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      handle = list_ops.tensor_list_push_back(handle, [i, 2])
      handle, tensor = list_ops.tensor_list_pop_back(handle, dtypes.int32)
      return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_pop_back_no_shape_capture(self):
    h = list_ops.tensor_list_reserve([2], 1, dtypes.int32)
    h = list_ops.tensor_list_push_back(h, [1, 2])

    def loop_fn(i):
      handle, tensor = list_ops.tensor_list_pop_back(h, dtypes.int32)
      handle = list_ops.tensor_list_push_back(handle, [1, i])
      return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_pop_back_with_shape(self):

    @def_function.function
    def loop_fn(i):
      with backprop.GradientTape() as tape:
        handle = list_ops.tensor_list_reserve(None, 1, dtypes.float32)
        x = math_ops.cast(i, dtypes.float32)[None]
        tape.watch(x)
        handle = list_ops.tensor_list_push_back(handle, x)
        stacked = list_ops.tensor_list_stack(handle, dtypes.float32)
      list_grad = tape.gradient(stacked, x, x)
      self.assertEqual("TensorListPopBack", list_grad.op.type)
      return list_grad, stacked, list_grad.op.inputs[1]

    self._test_loop_fn(loop_fn, 3)

  def test_create_outside_and_scatter(self):
    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)

    def loop_fn(i):
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=h)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_create_inside_and_scatter(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_create_outside_and_gather(self):
    handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
    handle = list_ops.tensor_list_scatter([[2, 3]], [0], input_handle=handle)
    handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)

    def loop_fn(i):
      return (list_ops.tensor_list_gather(handle, [0, 1], dtypes.int32),
              list_ops.tensor_list_gather(handle, [i], dtypes.int32))

    self._test_loop_fn(loop_fn, 2)

  def test_create_inside_and_gather(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return (list_ops.tensor_list_gather(handle, [0, 1], dtypes.int32),
              list_ops.tensor_list_gather(handle, [i], dtypes.int32))

    self._test_loop_fn(loop_fn, 2)

  def test_create_inside_and_concat(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return gen_list_ops.tensor_list_concat_v2(
          handle,
          element_dtype=dtypes.int32,
          element_shape=[2],
          leading_dims=[])

    output = pfor_control_flow_ops.pfor(loop_fn, 2)
    self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], output[0])
    self.assertAllClose([[2, 2], [2, 2]], output[1])

  def test_create_outside_and_concat(self):
    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)

    def loop_fn(i):
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=h)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return gen_list_ops.tensor_list_concat_v2(
          handle,
          element_dtype=dtypes.int32,
          element_shape=[2],
          leading_dims=[])

    output = pfor_control_flow_ops.pfor(loop_fn, 2)
    self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], output[0])
    self.assertAllClose([[2, 2], [2, 2]], output[1])

  def test_tensor_list_from_tensor(self):
    t = random_ops.random_uniform([2, 3, 4])

    def loop_fn(i):
      handle = list_ops.tensor_list_from_tensor(array_ops.gather(t, i), [4])
      return list_ops.tensor_list_stack(handle, t.dtype)

    self._test_loop_fn(loop_fn, 2)

  def test_tensor_list_reserve_while_loop(self):
    # Here a loop invariant TensorList is captured by a while_loop, which then
    # performs loop dependent operations on it, resulting in a loop variant
    # output. This forces stacking of the variant handle captured by the
    # while_loop.
    # We handle this particular case by forcing vectorization of
    # TensorListReserve operation.
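    # Added worked example (values inferred from the loop below, not stated in
    # the original comment): the inner while_loop writes the loop-dependent
    # value i at positions j = 0 and 1, so each pfor iteration stacks to
    # [i, i] and the overall output is [[0, 0], [1, 1]].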
    v2_enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < 2, lambda j, h:
          (j + 1, list_ops.tensor_list_set_item(h, j, i)), (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 2)
    if not v2_enabled:
      control_flow_v2_toggles.disable_control_flow_v2()

  def test_tensor_list_addn_already_stacked(self):

    def loop_fn(i):
      l1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      l1 = list_ops.tensor_list_set_item(l1, 0, i)
      l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      l2 = list_ops.tensor_list_set_item(l2, 1, i)
      return list_ops.tensor_list_stack(math_ops.add_n([l1, l2]), dtypes.int32)

    self._test_loop_fn(loop_fn, 2)

  def test_tensor_list_addn_stacking_required(self):
    l1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    l1 = list_ops.tensor_list_set_item(l1, 1, 1)

    def loop_fn(i):
      l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      l2 = list_ops.tensor_list_set_item(l2, 1, i)
      return list_ops.tensor_list_stack(
          math_ops.add_n([l1, l2]), dtypes.int32)

    self._test_loop_fn(loop_fn, 2)


class OptionalTest(PForTestCase):

  def test_optional_from_value(self):

    def loop_fn(i):
      o = gen_dataset_ops.optional_from_value(
          [i, i + 1, constant_op.constant(3)])
      gen_dataset_ops.optional_none()
      return gen_dataset_ops.optional_get_value(
          o, [dtypes.int32, dtypes.int32, dtypes.int32],
          [[], [], []])

    self._test_loop_fn(loop_fn, 2)


class StackTest(PForTestCase):

  @test_util.run_v1_only("b/122612051")
  def test_stack_inside_loop_invariant(self):

    def loop_fn(_):
      s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
      op1 = data_flow_ops.stack_push_v2(s, 1)
      with ops.control_dependencies([op1]):
        op2 = data_flow_ops.stack_push_v2(s, 2)
      with ops.control_dependencies([op2]):
        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      with ops.control_dependencies([e2]):
        e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      return e1, e2

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_stack_inside_push_loop_dependent(self):

    def loop_fn(i):
      s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
      op1 = data_flow_ops.stack_push_v2(s, i)
      with ops.control_dependencies([op1]):
        op2 = data_flow_ops.stack_push_v2(s, 2)
      with ops.control_dependencies([op2]):
        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      with ops.control_dependencies([e2]):
        e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      return e1, e2

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_stack_outside_pop(self):
    s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
    op = data_flow_ops.stack_push_v2(s, 5)
    with ops.control_dependencies([op]):
      op = data_flow_ops.stack_push_v2(s, 6)
    with ops.control_dependencies([op]):
      op = data_flow_ops.stack_push_v2(s, 7)

    def loop_fn(_):
      e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      with ops.control_dependencies([e1]):
        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      return e1, e2

    with ops.control_dependencies([op]):
      e1, e2 = pfor_control_flow_ops.pfor(loop_fn, iters=2)
    with ops.control_dependencies([e1, e2]):
      e3 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
    v1, v2, v3 = self._run_targets([e1, e2, e3], run_init=False)
    self.assertAllEqual([7, 7], v1)
    self.assertAllEqual([6, 6], v2)
    self.assertAllEqual(5, v3)

  @test_util.run_v1_only("b/122612051")
  def test_stack_outside_push(self):
    s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)

    def loop_fn(_):
      return data_flow_ops.stack_push_v2(s, 7)

    with self.assertRaisesRegex(ValueError, "StackPushV2 not allowed.*"):
      pfor_control_flow_ops.pfor(loop_fn, iters=2)


# TODO(agarwal): test nested while_loops. This currently requires converting a
# tf.cond.
class WhileV1Test(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.disable_control_flow_v2()
    super(WhileV1Test, self).setUp()

  def tearDown(self):
    if self._enabled:
      control_flow_v2_toggles.enable_control_flow_v2()
    super(WhileV1Test, self).tearDown()

  def test_while_outside_loop(self):

    x = control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    def loop_fn(i):
      return x + i

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_invariant_while(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_invariant_while_with_control_dependency(self):

    def loop_fn(i):
      with ops.control_dependencies([i]):
        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
                                           [0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while_with_stateful_ops(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4, lambda j, x:
          (j + 1, x + random_ops.random_uniform([])), [0, 0.])[0]

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
                                         (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while(self):
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      _, total = control_flow_ops.while_loop(
          lambda j, _: j < lengths_i, lambda j, t:
          (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
      return total

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while_jacobian(self):
    x = random_ops.random_uniform([1, 3])
    y = random_ops.random_uniform([3, 3])

    # out = x @ y @ y @ y @ y, where @ is the matmul operator.
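    # Added worked step (inferred from the expected_output computation below):
    # out = x @ y^4 is a [1, 3] row vector, so the Jacobian d(out)/d(x) is
    # (y^4)^T, and row i of the pfor output below is d(out[0, i])/d(x).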
    _, out = control_flow_ops.while_loop(
        lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
        [0, x])

    def loop_fn(i):
      out_i = array_ops.gather(out, i, axis=1)
      return array_ops.reshape(gradient_ops.gradients(out_i, x)[0], [-1])

    out = pfor_control_flow_ops.pfor(loop_fn, iters=3)

    # The above code does not work with tf.while_loop instead of pfor. So we
    # manually compute the expected output here.
    # Note that the gradient of the output w.r.t. x is (y @ y @ y @ y)^T.
    expected_output = y
    for _ in range(3):
      expected_output = math_ops.matmul(expected_output, y)
    expected_output = array_ops.transpose(expected_output, [1, 0])

    with session.Session() as sess:
      out, expected = sess.run([out, expected_output])
      self.assertAllClose(expected, out)

  @test_util.run_v1_only("b/122612051")
  def test_tensor_array_as_loop_variable(self):

    def loop_fn(i):

      def body(j, ta):
        ta = ta.write(j, i + j * j)
        return j + 1, ta

      _, ta = control_flow_ops.while_loop(
          lambda j, _: j < 4, body,
          (0, tensor_array_ops.TensorArray(dtypes.int32, size=4)))
      return ta.stack()

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_read_tensor_array_partitioned_indices(self):
    # Note that the tensor array values are pfor loop dependent, and the while
    # loop termination condition is also dependent on the pfor iteration.
    def loop_fn(i):
      ta = tensor_array_ops.TensorArray(dtypes.int32, size=6)
      ta = ta.unstack(i + list(range(5)))

      def body(j, s):
        return j + 1, s + ta.read(j)

      _, s = control_flow_ops.while_loop(lambda j, _: j < i, body, (0, 0))
      return s

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_external_while_loop_grad(self):
    # Here we test that external while_loops that are extended from inside pfor
    # (due to gradient calls) are not actually converted. If the below were
    # converted, all pfor iterations would write to the same tensor array
    # indices.
    x = constant_op.constant(1.)

    def body(j, ta):
      ta = ta.write(j, x)
      return j + 1, ta

    _, ta = control_flow_ops.while_loop(
        lambda j, _: j < 4, body,
        (0, tensor_array_ops.TensorArray(dtypes.float32, size=4)))
    out = ta.stack()

    def loop_fn(i):
      out_i = array_ops.gather(out, i)
      return gradient_ops.gradients(out_i, x)[0]

    with session.Session() as sess:
      # out is [x, x, x]. Hence the gradients should be [1, 1, 1].
      self.assertAllEqual([1, 1, 1],
                          sess.run(pfor_control_flow_ops.pfor(loop_fn, 3)))

  @test_util.run_v1_only("b/122612051")
  def test_tensor_array_grad(self):
    inp = constant_op.constant(np.random.rand(3, 4, 2), dtype=dtypes.float32)
    ta = tensor_array_ops.TensorArray(dtypes.float32, size=3)
    ta = ta.unstack(inp)

    def loop_fn(i):

      def body(j, x):
        value = ta.gather([j])
        value = array_ops.gather(array_ops.reshape(value, [4, 2]), i)
        return j + 1, x + value

      _, out = control_flow_ops.while_loop(lambda j, _: j < 3, body,
                                           (0, array_ops.zeros([2])))
      out = math_ops.reduce_prod(out)
      return out, gradient_ops.gradients(out, inp)[0]

    pfor_out, pfor_out_grad = pfor_control_flow_ops.pfor(loop_fn, 4)
    # Note that tf.while_loop does not work in the setup above. So we manually
    # construct the equivalent computation of the above loops here.
    real_out = math_ops.reduce_sum(inp, axis=[0])
    real_out = math_ops.reduce_prod(real_out, axis=[1])
    # Note that gradients of real_out will accumulate the gradients across the
    # output value. Hence we do the same aggregation on pfor_out_grad.
    real_out_grad = gradient_ops.gradients(real_out, inp)[0]
    sum_pfor_out_grad = math_ops.reduce_sum(pfor_out_grad, axis=[0])

    with session.Session() as sess:
      v1, v2, v1_grad, v2_grad = sess.run(
          [pfor_out, real_out, sum_pfor_out_grad, real_out_grad])
      self.assertAllClose(v1, v2)
      self.assertAllClose(v1_grad, v2_grad)


def dynamic_lstm_input_fn(batch_size, state_size, max_steps):
  # We make inputs and sequence_length constant so that multiple session.run
  # calls produce the same result.
  inputs = constant_op.constant(
      np.random.rand(batch_size, max_steps, state_size), dtype=dtypes.float32)
  sequence_length = np.random.randint(0, size=[batch_size], high=max_steps + 1)
  sequence_length = constant_op.constant(sequence_length, dtype=dtypes.int32)
  return inputs, sequence_length
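

# Added commentary, not from the original file: create_dynamic_lstm below runs
# the cell one example at a time inside pfor, reading time-major inputs from a
# TensorArray and masking outputs/state past each example's sequence length,
# and also builds the reference rnn.dynamic_rnn graph over the full batch so
# callers can compare the two implementations.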
def create_dynamic_lstm(cell_fn, batch_size, state_size, max_steps):
  cell = cell_fn(state_size)
  inputs, sequence_length = dynamic_lstm_input_fn(batch_size, state_size,
                                                  max_steps)
  inputs_ta = tensor_array_ops.TensorArray(
      dtypes.float32, size=max_steps, element_shape=[batch_size, state_size])
  inputs_time_major = array_ops.transpose(inputs, [1, 0, 2])
  inputs_ta = inputs_ta.unstack(inputs_time_major)
  zeros = array_ops.zeros([state_size])

  def loop_fn(i):
    sequence_length_i = array_ops.gather(sequence_length, i)

    def body_fn(t, state, ta):
      inputs_t = array_ops.expand_dims(
          array_ops.gather(inputs_ta.read(t), i), 0)
      output, new_state = cell(inputs_t, state)
      output = array_ops.reshape(output, [-1])
      # TODO(agarwal): one optimization that dynamic_rnn uses is to avoid the
      # array_ops.where when t < min(sequence_length). Doing that requires
      # supporting tf.cond pfor conversion.
      done = t >= sequence_length_i
      output = array_ops.where(done, zeros, output)
      ta = ta.write(t, output)
      new_state = [
          array_ops.where(done, s, ns)
          for s, ns in zip(nest.flatten(state), nest.flatten(new_state))
      ]
      new_state = nest.pack_sequence_as(state, new_state)
      return t + 1, new_state, ta

    def condition_fn(t, _, unused):
      del unused
      return t < max_steps

    initial_state = cell.zero_state(1, dtypes.float32)
    _, state, ta = control_flow_ops.while_loop(condition_fn, body_fn, [
        0, initial_state,
        tensor_array_ops.TensorArray(dtypes.float32, max_steps)
    ])

    new_state = [array_ops.reshape(x, [-1]) for x in nest.flatten(state)]
    new_state = nest.pack_sequence_as(initial_state, new_state)
    return ta.stack(), new_state

  pfor_output = pfor_control_flow_ops.pfor(loop_fn, batch_size)
  tf_output = rnn.dynamic_rnn(
      cell,
      inputs,
      sequence_length=sequence_length,
      initial_state=cell.zero_state(batch_size, dtypes.float32))
  return pfor_output, tf_output


@test_util.run_all_in_graph_and_eager_modes
class WhileV2Test(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()
    super(WhileV2Test, self).setUp()

  def tearDown(self):
    if not self._enabled:
      control_flow_v2_toggles.disable_control_flow_v2()
    super(WhileV2Test, self).tearDown()

  def test_while_outside_loop(self):

    def _f():
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    def loop_fn(i):
      return _f() + i

    self._test_loop_fn(loop_fn, 3)

  def test_invariant_while(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    self._test_loop_fn(loop_fn, 3)

  def test_invariant_while_with_control_dependency(self):

    def loop_fn(i):
      with ops.control_dependencies([i]):
        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
                                           [0])

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_stateful_ops(self):

    def loop_fn(_):
      j, _ = control_flow_ops.while_loop(
          lambda j, x: j < 4, lambda j, x:
          (j + 1, x + random_ops.random_uniform([])), [0, 0.])
      return j

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_variable(self):
    v = resource_variable_ops.ResourceVariable(5.)

    def loop_fn(_):
      _, output = control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
                                              (j + 1, x + v), [0, 0.])
      return output

    self._test_loop_fn(loop_fn, 3)

  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
                                         (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3)

  def test_while(self):
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      return control_flow_ops.while_loop(
          lambda j, _: j < lengths_i, lambda j, t:
          (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])

    self._test_loop_fn(loop_fn, 3)

  def test_while_change_input_invariance(self):
    # This tests cases where a loop invariant input to while has loop dependent
    # operations applied to it inside the while body.
    # It also tests inputs that are passed through.
    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, *_: j < i, lambda j, x, y, z, w:
          (j + 1, x + i, y + x, z, w), [
              0,
              constant_op.constant(0),
              constant_op.constant(1), i,
              constant_op.constant(2)
          ])

    self._test_loop_fn(loop_fn, 3)

  def test_while_shape_invariants(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, *_: j < 4,
          lambda j, x, y: (j + 1, x + i, y + 1),
          [0, constant_op.constant([0, 1]),
           constant_op.constant([2, 3])],
          shape_invariants=[
              None,
              tensor_shape.TensorShape([2]),
              tensor_shape.TensorShape([2])
          ])

    self._test_loop_fn(loop_fn, 3)

  def test_while_jacobian(self):
    # Note that we wrap the code below in a tf.function since we don't want the
    # while_loop call to be evaluated eagerly using a python loop.
    @def_function.function
    def _f(x, y, use_pfor):
      # out = x @ y @ y @ y @ y, where @ is the matmul operator.
      _, out = control_flow_ops.while_loop(
          lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
          [0, x])

      def loop_fn(i):
        out_i = array_ops.gather(out, i, axis=1)
        grad = gradient_ops.gradients(out_i, x)
        return array_ops.reshape(grad[0], [-1])

      if use_pfor:
        return pfor_control_flow_ops.pfor(loop_fn, iters=3)
      else:
        return pfor_control_flow_ops.for_loop(
            loop_fn, iters=3, loop_fn_dtypes=out.dtype)

    x = constant_op.constant(np.random.uniform(size=(1, 3)))
    y = constant_op.constant(np.random.uniform(size=(3, 3)))
    self.assertAllClose(_f(x, y, True), _f(x, y, False))

  def test_scan(self):
    np.random.seed(seed=42)
    data = np.random.randn(3).astype(np.float32)

    def log_prob(x):
      return math_ops.reduce_sum(functional_ops.scan_v2(
          lambda _, yi: (x - yi)**2,
          elems=data,
          initializer=constant_op.constant(0.)))

    x = variables.Variable(array_ops.ones([2]))
    self.evaluate(x.initializer)
    v_log_prob = lambda x: pfor_control_flow_ops.vectorized_map(log_prob, x)
    theoretical, numerical = gradient_checker_v2.compute_gradient(
        v_log_prob, (x,), delta=1e-3)
    self.assertAllClose(theoretical, numerical, rtol=1e-2)


@test_util.run_all_in_graph_and_eager_modes
class NestedControlFlowTest(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()
    super(NestedControlFlowTest, self).setUp()

  def tearDown(self):
    if not self._enabled:
      control_flow_v2_toggles.disable_control_flow_v2()
    super(NestedControlFlowTest, self).tearDown()

  def _cond(self, f=None, split=0):
    if f is None:
      f = lambda x, y: (x, y)

    def _f(x, y):
      return control_flow_ops.cond(y > split, lambda: f(x, y), lambda:
                                   (x + 1., y))

    return _f

  def _while(self, f=None):
    if f is None:
      f = lambda x, y: (x, y)

    def _f(x, y):
      return control_flow_ops.while_loop(
          lambda j, _: j < y, lambda j, t:
          (j + 1, t + array_ops.gather(f(x, y)[0], j)), [0, x])[1], y

    return _f

  def _test_helper(self, f):
    x = random_ops.random_uniform([5, 5])
    y = constant_op.constant([4, -1, 2, -2, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      y_i = array_ops.gather(y, i)
      return f(x_i, y_i)

    self._test_loop_fn(loop_fn, 5)

  def test_cond_while(self):
    self._test_helper(self._cond(self._while()))

  def test_while_cond(self):
    self._test_helper(self._while(self._cond()))

  def test_while_while(self):
    self._test_helper(self._while(self._while()))

  def test_cond_cond(self):
    self._test_helper(self._cond(self._cond()))


@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class StatelessIfTest(PForTestCase):

  def test_loop_variant_cond(self):
    x = [1, 2, 3, 4, 5.]
    y = 2.5

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      # Note that the output has a combination of then and else branches being
      # loop variant / invariant.
      return cond_v2.cond_v2(x_i < y, lambda: (y - x_i, y, 1., 2.), lambda:
                             (x_i - y, 0., y, 3.))

    self._test_loop_fn(loop_fn, iters=5)

  def test_loop_invariant_cond(self):
    x = [1, 2, 3, 4, 5.]
    y = 0.5
    z = random_ops.random_uniform([])

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      # Note that the output has a combination of then and else branches being
      # loop variant / invariant.
      return cond_v2.cond_v2(z < y, lambda: (y - x_i, y, 1., 2.), lambda:
                             (x_i - y, 0., y, 3.))

    self._test_loop_fn(loop_fn, iters=5)

  def test_empty_branch(self):
    x = [1, 2, 3, 4, 5.]
    y = 6.

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return cond_v2.cond_v2(
          x_i < y,  # Always true here, so the else branch is empty.
          lambda: (y - x_i, y, 1., 2.),
          lambda: (x_i - y, 0., y, 3.))

    self._test_loop_fn(loop_fn, iters=5)


@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class IfTest(PForTestCase):

  def test_read_var(self):
    self.skipTest("b/156438918")  # Flaky

    x = [1, 2, 3, 4, 5.]
    y = 2.5
    z = resource_variable_ops.ResourceVariable(5.)

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return cond_v2.cond_v2(x_i < y, lambda: z - x_i, lambda: z + x_i)

    self._test_loop_fn(loop_fn, iters=5)


class RNNTest(PForTestCase):

  @test_util.run_v1_only("b/122612051")
  def test_dynamic_rnn(self):
    pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, 3, 5,
                                                   7)
    self.run_and_assert_equal(pfor_outputs, tf_outputs)

  @test_util.run_v1_only("b/122612051")
  def test_dynamic_lstm(self):
    pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicLSTMCell, 3, 5,
                                                   7)
    self.run_and_assert_equal(pfor_outputs, tf_outputs)


# TODO(agarwal): Benchmark numbers on GPU for graphs based on while_loop
# conversion don't look good. Some of it seems like a lot of copies between
# host and device. Optimize that.
class Benchmarks(test.Benchmark):

  def _run(self, targets, iters, name=None):

    def _done(t):
      # Note that we don't use tf.control_dependencies since that will not make
      # sure that the computation on GPU has actually finished. So we fetch the
      # first element of the output, and assume that this will not be called on
      # empty tensors.
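      # The reshape below handles outputs of any rank, so arbitrary benchmark
      # targets can be reduced to a single scalar fetch.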
      return array_ops.gather(array_ops.reshape(t, [-1]), 0)

    targets = [_done(x) for x in nest.flatten(targets)]
    sess = session.Session()
    with sess:
      init = variables.global_variables_initializer()
      sess.run(init)
      run_fn = sess.make_callable(targets)
      run_fn()  # Warm up
      begin = time.time()
      for _ in range(iters):
        run_fn()
      end = time.time()
    avg_time_ms = 1000 * (end - begin) / iters
    self.report_benchmark(iters=iters, wall_time=avg_time_ms, name=name)
    return avg_time_ms

  def benchmark_sess_run_overhead(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0)
      self._run(x, 10000, name="session_run_overhead")

  def benchmark_add(self):
    with ops.Graph().as_default():
      n = 256
      params = 1000
      x = random_ops.random_normal([n, params])
      y = random_ops.random_normal([n, params])

      def loop_fn(i):
        x_i = array_ops.gather(x, i)
        y_i = array_ops.gather(y, i)
        return x_i + y_i

      pfor_outputs = pfor_control_flow_ops.pfor(loop_fn, n)
      while_outputs = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, n)
      manual = x + y

      self._run(manual, 1000, name="manual_add")
      self._run(pfor_outputs, 1000, name="pfor_add")
      self._run(while_outputs, 100, name="while_add")

  def benchmark_matmul(self):
    with ops.Graph().as_default():
      n = 1024
      params = 1000
      x = random_ops.random_normal([n, params])
      y = random_ops.random_normal([params, params])

      def loop_fn(i):
        x_i = array_ops.expand_dims(array_ops.gather(x, i), 0)
        return math_ops.matmul(x_i, y)

      pfor_outputs = pfor_control_flow_ops.pfor(loop_fn, n)
      while_outputs = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, n)
      manual = math_ops.matmul(x, y)

      self._run(manual, 1000, name="manual_matmul")
      self._run(pfor_outputs, 1000, name="pfor_matmul")
      self._run(while_outputs, 100, name="while_matmul")

  def benchmark_map_fn(self):
    with ops.Graph().as_default():
      b = 256
      params = 1000
      inp = random_ops.random_normal((b, params))
      fn = lambda x: x * x

      def pfor_map_fn(f, x):
        return pfor_control_flow_ops.pfor(lambda i: f(array_ops.gather(x, i)),
                                          array_ops.shape(x)[0])

      map_output = map_fn.map_fn(fn, inp)
      pfor_output = pfor_map_fn(fn, inp)

      self._run(map_output, 100, name="tf_map_fn")
      self._run(pfor_output, 100, name="pfor_map_fn")

  def benchmark_basic_while(self):
    with ops.Graph().as_default():

      def loop_fn(i):
        _, s = control_flow_ops.while_loop(lambda t, x: t < i, lambda t, x:
                                           (t + 1, x + i), [0, 0])
        return s

      iters = 50
      pfor_output = pfor_control_flow_ops.pfor(loop_fn, iters)
      for_loop_output = pfor_control_flow_ops.for_loop(loop_fn, dtypes.int32,
                                                       iters)
      self._run(pfor_output, 100, name="pfor_basic")
      self._run(for_loop_output, 100, name="for_loop_basic")

  def benchmark_dynamic_rnn(self):
    with ops.Graph().as_default():
      pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, 128,
                                                     512, 16)
      self._run(pfor_outputs, 100, name="pfor_rnn")
      self._run(tf_outputs, 100, name="tf_rnn")

  def benchmark_reduction(self):
    n = 1024
    with ops.Graph().as_default():
      x = random_ops.random_uniform([n, n])
      w = random_ops.random_uniform([n, n])

      def loop_fn(i, pfor_config):
        x_i = array_ops.gather(x, i)
        return math_ops.reduce_sum(
            math_ops.matmul(pfor_config.reduce_concat(x_i), w))

      # Note that output_reduction will be tiled, so there may be some minor
      # overheads compared to output_no_reduction.
      output_reduction = pfor_control_flow_ops.pfor(loop_fn, n)
      output_no_reduction = math_ops.reduce_sum(math_ops.matmul(x, w))
      # Benchmark to test that reduction does not add overhead and its output
      # is treated as loop invariant.
      self._run(output_reduction, 30, name="matmul_reduction")
      self._run(output_no_reduction, 30, name="matmul_no_reduction")


class SparseTest(PForTestCase):

  @test_util.run_v1_only("b/122612051")
  def test_var_loop_len(self):
    num_iters = array_ops.placeholder(dtypes.int32)

    def loop_fn(_):
      return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6],
                                        [3])  # Dense: [4, 5, 6]

    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    with self.cached_session() as sess:
      sess.run(pfor, feed_dict={num_iters: 3})

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_none_stacked(self):
    num_iters = 10

    def loop_fn(_):
      return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6],
                                        [3])  # Dense: [4, 5, 6]

    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)

    indices = [[i, j] for i in range(num_iters) for j in range(3)]
    values = [4, 5, 6] * num_iters
    dense_shapes = [num_iters, 3]
    # Expected result: [[4, 5, 6], [4, 5, 6], [4, 5, 6], ...]
    manual = sparse_tensor.SparseTensor(indices, values, dense_shapes)
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_all_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      indices = array_ops.expand_dims(i, 0)
      return sparse_tensor.SparseTensor(indices, i, i + 1)  # [0, ..., 0, i]

    # Expected result: [[0], [0, 1], [0, 0, 2], [0, 0, 0, 3], ...]
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)],
                                        list(range(num_iters)),
                                        (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_indices_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      indices = array_ops.expand_dims(i, 0)
      return sparse_tensor.SparseTensor(indices, [1], [num_iters])

    # Expected result: an identity matrix of size num_iters x num_iters.
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)],
                                        [1] * num_iters, (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_values_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      return sparse_tensor.SparseTensor([[0]], i, [num_iters])  # [i, 0, ..., 0]

    # Expected result: [[0, 0, ...], [1, 0, ...], [2, 0, ...], ...]
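    # Only the values component is loop dependent here; the indices and dense
    # shape are loop invariant.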
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)],
                                        list(range(num_iters)),
                                        (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_shapes_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      return sparse_tensor.SparseTensor([[0]], [1], i + 1)  # [1, 0, ..., 0]

    # Expected result: [[1, 0, 0, ...], [1, 0, 0, ...], ...]
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)],
                                        [1] * num_iters, (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_shapes_stacked_2D(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i + 1, dtypes.int64), 0)
      shape = array_ops.concat([i, i], 0)
      return sparse_tensor.SparseTensor([[0, 0]], [1], shape)  # [1, 0, ..., 0]

    # Expected result: [[[1, 0, ...], [0, ..., 0], [0, ..., 0], ...], ...]
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, 0, 0] for i in range(num_iters)],
                                        [1] * num_iters,
                                        (num_iters, num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)


# Dummy CompositeTensor to test CompositeTensor support.
class Particle(composite_tensor.CompositeTensor):
  """A (batch of) particles each defined by a mass and a scalar velocity."""

  def __init__(self, mass, velocity):
    mass = ops.convert_to_tensor(mass)
    velocity = ops.convert_to_tensor(velocity)
    self.shape = array_ops.broadcast_static_shape(mass.shape, velocity.shape)
    self.mass = mass
    self.velocity = velocity

  @property
  def _type_spec(self):
    return ParticleSpec(
        type_spec.type_spec_from_value(self.mass),
        type_spec.type_spec_from_value(self.velocity))


class ParticleSpec(type_spec.BatchableTypeSpec):

  def __init__(self, mass, velocity):
    self.shape = array_ops.broadcast_static_shape(
        mass.shape, velocity.shape)
    self.mass = mass
    self.velocity = velocity

  def _serialize(self):
    return (self.mass, self.velocity)

  @property
  def value_type(self):
    return Particle

  @property
  def _component_specs(self):
    return (self.mass, self.velocity)

  def _to_components(self, value):
    return (value.mass, value.velocity)

  def _from_components(self, components):
    return Particle(*components)

  def _pad_shape_to_full_rank(self, s):
    """Pads component shapes with 1's so all components have the same rank."""
    return tensor_shape.TensorShape(
        [1] * (self.shape.ndims - s.ndims)).concatenate(s)

  def _batch(self, batch_size):
    return ParticleSpec(
        mass=tensor_spec.TensorSpec(
            dtype=self.mass.dtype,
            shape=tensor_shape.TensorShape([batch_size]).concatenate(
                self._pad_shape_to_full_rank(self.mass.shape))),
        velocity=tensor_spec.TensorSpec(
            dtype=self.velocity.dtype,
            shape=tensor_shape.TensorShape([batch_size]).concatenate(
                self._pad_shape_to_full_rank(self.velocity.shape))))

  def _unbatch(self):
    return ParticleSpec(
        tensor_spec.TensorSpec(dtype=self.mass.dtype,
                               shape=self.mass.shape[1:]),
        tensor_spec.TensorSpec(dtype=self.velocity.dtype,
                               shape=self.velocity.shape[1:]))

  def _to_tensor_list(self, value):
    return [array_ops.reshape(
        value.mass,
        self._pad_shape_to_full_rank(value.mass.shape)),
            array_ops.reshape(
                value.velocity,
                self._pad_shape_to_full_rank(value.velocity.shape))]


class CompositeTensorTest(PForTestCase, parameterized.TestCase):

  @parameterized.parameters((None,), (3,))
  def test_create_composite_inside_loop(self, parallel_iterations):
    num_particles = 10
    velocities = random_ops.random_uniform([num_particles])
    particles = pfor_control_flow_ops.pfor(
        # Build a batch of particles all with the same mass.
        lambda i: Particle(mass=4., velocity=array_ops.gather(velocities, i)),
        num_particles,
        parallel_iterations=parallel_iterations)
    particles_mass, particles_velocity, velocities = self.evaluate(
        (particles.mass, particles.velocity, velocities))
    self.assertAllEqual(particles_mass, 4. * np.ones([num_particles]))
    self.assertAllEqual(particles_velocity, velocities)

  @parameterized.parameters((None,), (3,))
  def test_composite_is_converted_to_batched_tensor(
      self, parallel_iterations):
    particles = pfor_control_flow_ops.pfor(
        lambda _: Particle(mass=random_ops.random_uniform([3]),  # pylint: disable=g-long-lambda
                           velocity=random_ops.random_uniform([5, 3])),
        4,
        parallel_iterations=parallel_iterations)
    # Naively batching the component shapes would give `[4, 3]` and `[4, 5, 3]`
    # which have no consistent broadcast shape; padding the mass to full rank
    # first yields broadcast-compatible batched shapes.
    self.assertAllEqual(particles.mass.shape, [4, 1, 3])
    self.assertAllEqual(particles.velocity.shape, [4, 5, 3])

  def test_vectorized_map_gathers_composite_tensors(self):
    particles = Particle(mass=[1., 2., 3., 4., 5.],
                         velocity=[1., 2., 3., 4., 5.])
    self.assertAllEqual(
        pfor_control_flow_ops.vectorized_map(
            lambda x: x.mass * x.velocity, particles),
        particles.mass * particles.velocity)

  def test_vectorized_map_of_ragged_tensors(self):
    # vectorized_map should be able to handle ragged tensors as long as they
    # are not *actually* ragged.
    ragged = ragged_tensor.RaggedTensor.from_uniform_row_length(
        ragged_tensor.RaggedTensor.from_row_lengths(
            values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
            row_lengths=[3, 3, 3, 3]),
        uniform_row_length=2)  # Overall shape [2, 2, 3].
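    # Each iteration sees a [2, 3] slice with uniform row lengths, so the
    # to_tensor call below requires no actual padding.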
    self.assertAllEqual(
        pfor_control_flow_ops.vectorized_map(
            lambda x: x.to_tensor(shape=[2, 3]), ragged),
        ragged.to_tensor(shape=[2, 2, 3]))


class ParsingTest(PForTestCase):

  def test_decode_csv(self):
    csv_tensor = constant_op.constant([["1:2:3"], ["::"], ["7:8:9"]])
    kwargs = {"record_defaults": [[10], [20], [30]], "field_delim": ":"}

    def loop_fn(i):
      line = array_ops.gather(csv_tensor, i)
      return parsing_ops.decode_csv(line, **kwargs)

    self._test_loop_fn(loop_fn, iters=3)

  @test_util.run_v1_only("b/122612051")
  def test_parse_single_example(self):

    def _int64_feature(*values):
      return feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=values))

    def _bytes_feature(*values):
      return feature_pb2.Feature(
          bytes_list=feature_pb2.BytesList(
              value=[v.encode("utf-8") for v in values]))

    examples = constant_op.constant([
        example_pb2.Example(
            features=feature_pb2.Features(
                feature={
                    "dense_int": _int64_feature(i),
                    "dense_str": _bytes_feature(str(i)),
                    "sparse_int": _int64_feature(i, i * 2, i * 4, i * 8),
                    "sparse_str": _bytes_feature(*["abc"] * i)
                })).SerializeToString() for i in range(10)
    ])

    features = {
        "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0),
        "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""),
        "sparse_int": parsing_ops.VarLenFeature(dtypes.int64),
        "sparse_str": parsing_ops.VarLenFeature(dtypes.string),
    }

    def loop_fn(i):
      example_proto = array_ops.gather(examples, i)
      f = parsing_ops.parse_single_example(example_proto, features)
      return f

    pfor = pfor_control_flow_ops.pfor(loop_fn, iters=10)
    manual = parsing_ops.parse_example(examples, features)
    self.run_and_assert_equal(pfor, manual)


class PartitionedCallTest(PForTestCase):

  def test_simple(self):

    @def_function.function
    def f(x):
      return math_ops.square(x) + 1

    z = random_ops.random_uniform([4])

    def loop_fn(i):
      return f(array_ops.gather(z, i))

    self._test_loop_fn(loop_fn, 4)

  def test_nested_calls(self):

    @def_function.function
    def inner(x):
      return math_ops.square(x)

    @def_function.function
    def outer(y):
      return math_ops.reduce_sum(inner(y)) + 2

    z = random_ops.random_uniform([4, 2])

    def loop_fn(i):
      return outer(array_ops.gather(z, i))

    self._test_loop_fn(loop_fn, 4)

  def test_nested_definition(self):

    @def_function.function
    def outer(y):

      @def_function.function
      def inner(x):
        return math_ops.square(x) + 1

      return math_ops.reduce_sum(inner(y)) + 2

    z = random_ops.random_uniform([4, 2])

    def loop_fn(i):
      return outer(array_ops.gather(z, i))

    self._test_loop_fn(loop_fn, 4)

  def test_gradients(self):

    @def_function.function
    def f(x):
      return math_ops.square(x) + 1

    z = random_ops.random_uniform([4, 2])

    def loop_fn(i):
      z_i = array_ops.gather(z, i)
      with backprop.GradientTape() as g:
        g.watch(z_i)
        out = f(z_i)
      return out, g.gradient(out, z_i)

    self._test_loop_fn(loop_fn, 4)

  def test_stateful_with_gradients(self):

    z = random_ops.random_uniform([4, 2])
    v = variables.Variable(z[0])

    @def_function.function
    def f(x):
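      # Reading the captured variable v makes f stateful, so this exercises
      # gradients through a stateful function call.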
      return math_ops.square(x) + v + 1

    def loop_fn(i):
      z_i = array_ops.gather(z, i)
      with backprop.GradientTape() as g:
        g.watch(z_i)
        out = f(z_i)
      return out, g.gradient(out, z_i)

    self._test_loop_fn(loop_fn, 4)


class SpectralTest(PForTestCase, parameterized.TestCase):

  @parameterized.parameters(
      (fft_ops.fft,),
      (fft_ops.fft2d,),
      (fft_ops.fft3d,),
      (fft_ops.ifft,),
      (fft_ops.ifft2d,),
      (fft_ops.ifft3d,),
  )
  def test_fft(self, op_func):
    shape = [2, 3, 4, 3, 4]
    x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return op_func(x_i)

    self._test_loop_fn(loop_fn, 2)

  @parameterized.parameters(
      (fft_ops.rfft,),
      (fft_ops.rfft2d,),
      (fft_ops.rfft3d,),
  )
  def test_rfft(self, op_func):
    for dtype in (dtypes.float32, dtypes.float64):
      x = random_ops.random_uniform([2, 3, 4, 3, 4], dtype=dtype)

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x_i = array_ops.gather(x, i)
        return op_func(x_i)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 2)

  @parameterized.parameters(
      (fft_ops.irfft,),
      (fft_ops.irfft2d,),
      (fft_ops.irfft3d,),
  )
  def test_irfft(self, op_func):
    if config.list_physical_devices("GPU"):
      # TODO(b/149957923): The test is flaky.
      self.skipTest("b/149957923: irfft vectorization flaky")
    for dtype in (dtypes.complex64, dtypes.complex128):
      shape = [2, 3, 4, 3, 4]
      x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
      x = math_ops.cast(x, dtype=dtype)

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x_i = array_ops.gather(x, i)
        return op_func(x_i)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 2)


class VariableTest(PForTestCase):

  def test_create_variable_once(self):
    x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
    y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
    a_var = []

    def f(z):
      if not a_var:
        a_var.append(variables.Variable(lambda: y, name="a"))
      return math_ops.matmul(z, a_var[0] / 16)

    pfor_control_flow_ops.vectorized_map(f, x)

  @test_util.run_v2_only
  def test_create_variable_repeated(self):
    x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
    y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)

    def f(z):
      a_var = variables.Variable(lambda: y, name="a") / 4
      return math_ops.matmul(z, a_var / 16)

    # Note that this error is only raised under v2 behavior.
    with self.assertRaisesRegex(
        ValueError,
        "tf.function-decorated function tried to create variables on non-first"
    ):
      pfor_control_flow_ops.vectorized_map(f, x)

  @test_util.run_all_in_graph_and_eager_modes
  def test_variable_shape(self):
    v = resource_variable_ops.ResourceVariable([1, 2])

    def loop_fn(_):
      return resource_variable_ops.variable_shape(v.handle)

    self._test_loop_fn(loop_fn, 2)


if __name__ == "__main__":
  test.main()