# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for pfor and for_loop."""
# pylint: disable=g-direct-tensorflow-import

import functools
import sys
import time

from absl.testing import parameterized
import numpy as np

from google.protobuf import text_format

from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import importer
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients as gradient_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest


@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class PForTest(PForTestCase):

  def test_op_conversion_fallback_to_while_loop(self):
    # Note that we used top_k op for this test. If a converter gets defined for
    # it, we will need to find another op for which a converter has not been
    # defined.
    x = random_ops.random_uniform([3, 2, 4])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return nn.top_k(x_i)

    with self.assertRaisesRegex(ValueError, "No pfor vectorization"):
      self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
    self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)

  def test_parallel_iterations(self):
    for parallel_iterations in [2, 3, 8, 10]:
      x = random_ops.random_uniform([8, 3])

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        return array_ops.gather(x, i)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 8, parallel_iterations=parallel_iterations)
      self._test_loop_fn(
          loop_fn,
          4 * constant_op.constant(2),
          parallel_iterations=parallel_iterations)

  def test_parallel_iterations_preserves_static_shape(self):
    for parallel_iterations in [2, 3, 8, 10]:
      x = pfor_control_flow_ops.pfor(
          lambda _: random_ops.random_uniform([2, 3]),
          8,
          parallel_iterations=parallel_iterations)
      self.assertAllEqual(x.shape, [8, 2, 3])

  def test_parallel_iterations_zero(self):
    with self.assertRaisesRegex(ValueError, "positive integer"):
      pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=0)
    with self.assertRaisesRegex(TypeError, "positive integer"):
      pfor_control_flow_ops.for_loop(
          lambda i: 1, dtypes.int32, 8, parallel_iterations=0)

  def test_parallel_iterations_one(self):
    with self.assertRaisesRegex(ValueError, "Use `for_loop` instead"):
      pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=1)

  def test_vectorized_map(self):

    def compute(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    result = pfor_control_flow_ops.vectorized_map(compute,
                                                  array_ops.ones((10, 5, 3)))
    self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))

  def test_vectorized_map_with_dynamic_shape(self):

    def compute(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    x = array_ops.placeholder_with_default(
        array_ops.ones((10, 5, 3)), shape=None)
    result = pfor_control_flow_ops.vectorized_map(compute, x)
    self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))

  def test_where_shape(self):
    @def_function.function
    def f():
      a = constant_op.constant([[1.], [1.]])
      b = constant_op.constant([1.])
      result = pfor_control_flow_ops.vectorized_map(
          lambda x: array_ops.where(x > 0, x, b), a)
      return result.shape

    self.assertAllEqual([2, 1], f())

  def test_vectorized_map_broadcasts_unit_dimensions(self):
    convert_with_static_shape = ops.convert_to_tensor
    convert_with_dynamic_shape = (
        lambda x: array_ops.placeholder_with_default(x, shape=None))

    for convert in (convert_with_static_shape, convert_with_dynamic_shape):
      a = convert([3.1])
      b = convert([-2., 6., 9.])

      # One elem with leading unit dimension.
      a_plus_1 = pfor_control_flow_ops.vectorized_map(lambda a: a + 1, a)
      self.assertAllEqual(*self.evaluate((a_plus_1, a + 1)))

      # Two elems, both with leading unit dimension.
      a_plus_a = pfor_control_flow_ops.vectorized_map(sum, (a, a))
      self.assertAllEqual(*self.evaluate((a_plus_a, a + a)))

      # Elem w/ unit dimension broadcast against elem with batch dim.
      a_plus_b = pfor_control_flow_ops.vectorized_map(sum, (a, b))
      self.assertAllEqual(*self.evaluate((a_plus_b, a + b)))

  def test_vectorized_map_example_1(self):

    def outer_product(a):
      return math_ops.tensordot(a, a, 0)

    batch_size = 100
    a = array_ops.ones((batch_size, 32, 32))
    c = pfor_control_flow_ops.vectorized_map(outer_product, a)
    self.assertAllEqual((batch_size, 32, 32, 32, 32), c.shape)

  def test_disable_tf_function(self):
    def_function.run_functions_eagerly(True)
    # vectorized_map should ignore disabling tf.functions
    self.assertTrue(def_function.functions_run_eagerly())
    self.assertAllEqual([0, 1, 4, 9],
                        pfor_control_flow_ops.vectorized_map(
                            lambda x: x * x, math_ops.range(4)))
    self.assertTrue(def_function.functions_run_eagerly())
    def_function.run_functions_eagerly(False)


@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesTest(PForTestCase):

  def test_indexed_slices(self):

    def loop_fn(i):
      return indexed_slices.IndexedSlices(
          indices=i, values=array_ops.reshape(i, [1]), dense_shape=[3, 1])

    self._test_loop_fn(loop_fn, 2)

  def test_indexed_slices_components(self):

    def loop_fn(i):
      slices = indexed_slices.IndexedSlices(
          indices=i, values=array_ops.reshape(i, [1]), dense_shape=[3, 1])
      # Note that returning the components inside the slice avoids
      # densification, which may be more efficient.
      return slices.values, slices.indices

    self._test_loop_fn(loop_fn, 2)


@test_util.run_all_in_graph_and_eager_modes
class ReductionTest(PForTestCase):

  def test_reduce(self):

    def reduce_fn(p, q):
      return math_ops.reduce_mean(p + q, axis=0)

    x = random_ops.random_uniform([4, 3, 2])
    y = random_ops.random_uniform([4, 3, 2])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      y_i = array_ops.gather(y, i)
      reduced = pfor_config.reduce(reduce_fn, x_i, y_i)
      return reduced + x_i

    output = pfor_control_flow_ops.pfor(loop_fn, 4)
    ans = reduce_fn(x, y) + x
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_concat(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      vectorized_value = pfor_config.reduce_concat(x_i)
      mean_value = math_ops.reduce_mean(vectorized_value, axis=0)
      return x_i - mean_value

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_mean(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_mean(x_i)

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_sum(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_sum(x_i)

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_sum(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_class(self):
    x = random_ops.random_uniform([8, 3])

    class LoopFn:

      def __init__(self):
        pass

      def __call__(self, i, pfor_config):
        x_i = array_ops.gather(x, i)
        return x_i - pfor_config.reduce_mean(x_i)

    output = pfor_control_flow_ops.pfor(LoopFn(), 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_functools_partial(self):
    x = random_ops.random_uniform([8, 3])

    def fn(i, pfor_config, dummy=None):
      del dummy
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_mean(x_i)

    loop_fn = functools.partial(fn, dummy=1)
    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_parallel_iterations(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return pfor_config.reduce_sum(x_i)

    with self.assertRaisesRegex(ValueError,
                                "`parallel_iterations` currently unsupported"):
      pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2)

  def test_var_loop_len(self):
    if context.executing_eagerly():
      self.skipTest("Variable length not possible under eager execution.")

    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      return pfor_config.reduce_sum(array_ops.gather(x, i))

    num_iters = array_ops.placeholder(dtypes.int32)
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    with self.cached_session() as sess:
      sess.run(pfor, feed_dict={num_iters: 8})


@test_util.run_all_in_graph_and_eager_modes
class BitwiseTest(PForTestCase):

  def test_unary_cwise(self):
    for op in [bitwise_ops.invert]:
      x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x1 = array_ops.gather(x, i)
        return op(x1)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 3)

  def test_binary_cwise(self):
    binary_ops = [
        bitwise_ops.bitwise_and,
        bitwise_ops.bitwise_or,
        bitwise_ops.bitwise_xor,
        bitwise_ops.left_shift,
        bitwise_ops.right_shift,
    ]
    for op in binary_ops:
      x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)
      y = random_ops.random_uniform([3, 5], maxval=10, dtype=dtypes.int32)

      output_dtypes = []

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x1 = array_ops.gather(x, i)
        y1 = array_ops.gather(y, i)
        outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
        del output_dtypes[:]
        output_dtypes.extend(t.dtype for t in outputs)
        return outputs

      # pylint: enable=cell-var-from-loop
      self._test_loop_fn(loop_fn, 3)


@test_util.run_all_in_graph_and_eager_modes
class ImageTest(PForTestCase):

  def test_adjust_contrast(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_contrast(image, 2.0)

    self._test_loop_fn(loop_fn, 3)

  def test_adjust_hue(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_hue(image, .25)

    self._test_loop_fn(loop_fn, 3)

  def test_adjust_saturation(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_saturation(image, 0.1)

    self._test_loop_fn(loop_fn, 3)


@test_util.run_all_in_graph_and_eager_modes
@test_util.run_all_without_tensor_float_32("Uses matmul")
class NNTest(PForTestCase):

  def test_conv2d(self):
    x = random_ops.random_uniform([3, 2, 12, 12, 3])
    filt = random_ops.random_uniform([3, 3, 3, 7])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      return nn.conv2d(
          x1, filt, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_conv2d_backprop_input(self):
    x_shape = [2, 12, 12, 3]
    filt = random_ops.random_uniform([3, 3, 3, 7])
    grad = random_ops.random_uniform([3, 2, 5, 5, 7])

    def loop_fn(i):
      grad1 = array_ops.gather(grad, i)
      return nn.conv2d_backprop_input(
          x_shape,
          filt,
          grad1,
          strides=[1, 2, 2, 1],
          padding="VALID",
          data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_conv2d_backprop_filter(self):
    x = random_ops.random_uniform([3, 2, 12, 12, 3])
    x_0 = array_ops.gather(x, 0)
    filter_sizes = [3, 3, 3, 7]
    grad = random_ops.random_uniform([3, 2, 5, 5, 7])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      grad_i = array_ops.gather(grad, i)
      return [
          nn.conv2d_backprop_filter(
              inp,
              filter_sizes,
              grad_i,
              strides=[1, 2, 2, 1],
              padding="VALID",
              data_format="NHWC") for inp in [x_i, x_0]
      ]

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native(self):
    x = random_ops.random_uniform([3, 2, 12, 12, 3])
    filt = random_ops.random_uniform([3, 3, 3, 3, 2])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      filt1 = array_ops.gather(filt, i)
      return nn.depthwise_conv2d_native(
          x1, filt1, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_backprop_input(self):
    x_shape = [2, 12, 12, 3]
    filt = random_ops.random_uniform([3, 3, 3, 3, 2])
    grad = random_ops.random_uniform([3, 2, 5, 5, 6])

    def loop_fn(i):
      grad1 = array_ops.gather(grad, i)
      filt1 = array_ops.gather(filt, i)
      return nn.depthwise_conv2d_native_backprop_input(
          x_shape,
          filt1,
          grad1,
          strides=[1, 2, 2, 1],
          padding="VALID",
          data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_backprop_filter(self):
    x = random_ops.random_uniform([3, 2, 12, 12, 3])
    filter_sizes = [3, 3, 3, 2]
    grad = random_ops.random_uniform([3, 2, 5, 5, 6])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      grad_i = array_ops.gather(grad, i)
      return nn.depthwise_conv2d_native_backprop_filter(
          x_i,
          filter_sizes,
          grad_i,
          strides=[1, 2, 2, 1],
          padding="VALID",
          data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_nchw(self):
    if not test_util.is_gpu_available():
      self.skipTest("NCHW only works on GPU")
    x = random_ops.random_uniform([3, 2, 3, 12, 12])
    filt = random_ops.random_uniform([3, 3, 3, 3, 2])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      filt1 = array_ops.gather(filt, i)
      return nn.depthwise_conv2d_native(
          x1, filt1, strides=[1, 1, 2, 2], padding="VALID", data_format="NCHW")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_backprop_input_nchw(self):
    if not test_util.is_gpu_available():
      self.skipTest("NCHW only works on GPU")
    x_shape = [2, 3, 12, 12]
    filt = random_ops.random_uniform([3, 3, 3, 3, 2])
    grad = random_ops.random_uniform([3, 2, 6, 5, 5])

    def loop_fn(i):
      grad1 = array_ops.gather(grad, i)
      filt1 = array_ops.gather(filt, i)
      return nn.depthwise_conv2d_native_backprop_input(
          x_shape,
          filt1,
          grad1,
          strides=[1, 1, 2, 2],
          padding="VALID",
          data_format="NCHW")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_backprop_filter_nchw(self):
    if not test_util.is_gpu_available():
      self.skipTest("NCHW only works on GPU")
    x = random_ops.random_uniform([3, 2, 3, 12, 12])
    filter_sizes = [3, 3, 3, 2]
    grad = random_ops.random_uniform([3, 2, 6, 5, 5])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      grad_i = array_ops.gather(grad, i)
      return nn.depthwise_conv2d_native_backprop_filter(
          x_i,
          filter_sizes,
          grad_i,
          strides=[1, 1, 2, 2],
          padding="VALID",
          data_format="NCHW")

    self._test_loop_fn(loop_fn, 3)

  def test_roll(self):
    x = random_ops.random_uniform([3, 6, 7])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return manip_ops.roll(x_i, 3, axis=1)

    self._test_loop_fn(loop_fn, 3)

  def test_ensure_shape(self):
    x = random_ops.random_uniform([3, 6, 7])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return array_ops.ensure_shape(x_i, [6, 7])

    self._test_loop_fn(loop_fn, 3)

  def test_loop_variant_roll_shift(self):
    x = random_ops.random_uniform([3, 5, 6, 7])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return manip_ops.roll(x_i, [i - 2, -1, i], axis=[1, 2, 2])

    self._test_loop_fn(loop_fn, 3)

  def test_loop_variant_roll_scalar_shift(self):
    x = random_ops.random_uniform([5, 5, 6])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return manip_ops.roll(x_i, i, axis=0)

    self._test_loop_fn(loop_fn, 5)

  def test_avg_pool(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 3, 3, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = nn.avg_pool(
            x1,
            ksize,
            strides=[1, 2, 2, 1],
            padding="VALID",
            data_format="NHWC")
        loss = nn.l2_loss(output)
      return output, g.gradient(loss, x1)

    self._test_loop_fn(loop_fn, 3)

  def test_avg_pool3d(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([5, 3, 7, 6, 6, 5])
      g.watch(x)
      ksize = [1, 2, 2, 2, 1]
      strides = [1, 2, 2, 2, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = nn.avg_pool3d(
            x1, ksize, strides=strides, padding="VALID", data_format="NDHWC")
        loss = nn.l2_loss(output)
      return output, g.gradient(loss, x1)

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 3, 3, 1]
      strides = [1, 2, 2, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = nn.max_pool(
            x1, ksize, strides=strides, padding="VALID", data_format="NHWC")
        loss = nn.l2_loss(output)
        ones = array_ops.ones_like(output)
        g.watch(ones)
        grad = g.gradient(loss, x1, output_gradients=ones)
      grad_grad = g.gradient(grad, ones)
      return output, grad, grad_grad

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool_v2(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 3, 3, 1]
      strides = [1, 2, 2, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = gen_nn_ops.max_pool_v2(
            x1, ksize, strides=strides, padding="VALID", data_format="NHWC")
        loss = nn.l2_loss(output)
        ones = array_ops.ones_like(output)
        g.watch(ones)
        grad = g.gradient(loss, x1, output_gradients=ones)
      grad_grad = g.gradient(grad, ones)
      return output, grad, grad_grad

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool3d(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 1, 3, 3, 1]
      strides = [1, 1, 2, 2, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = nn.max_pool3d(
            x1, ksize, strides=strides, padding="VALID", data_format="NDHWC")
        loss = nn.l2_loss(output)
        ones = array_ops.ones_like(output)
        g.watch(ones)
        grad = g.gradient(loss, x1, output_gradients=ones)
      grad_grad = g.gradient(grad, ones)
      return output, grad, grad_grad

    self._test_loop_fn(loop_fn, 3)

  def test_fused_batch_norm(self):
    data_formats = ["NHWC"]
    if test.is_gpu_available():
      data_formats.append("NCHW")
    for is_training in (True, False):
      for data_format in data_formats:
        with backprop.GradientTape(persistent=True) as g:
          if data_format == "NCHW":
            x = random_ops.random_uniform([3, 1, 2, 5, 5])
          else:
            x = random_ops.random_uniform([3, 1, 5, 5, 2])
          g.watch(x)
          scale = random_ops.random_uniform([2])
          g.watch(scale)
          offset = random_ops.random_uniform([2])
          g.watch(offset)
          mean = None if is_training else random_ops.random_uniform([2])
          variance = None if is_training else random_ops.random_uniform([2])

        # pylint: disable=cell-var-from-loop
        def loop_fn(i):
          with g:
            x1 = array_ops.gather(x, i)
            outputs = nn.fused_batch_norm(
                x1,
                scale,
                offset,
                mean=mean,
                variance=variance,
                epsilon=0.01,
                data_format=data_format,
                is_training=is_training)
            outputs = list(outputs)
            # We only test the first value of outputs when is_training is
            # False. It looks like CPU and GPU have different outputs for
            # batch_mean and batch_variance for this case.
            if not is_training:
              outputs[1] = constant_op.constant(0.)
              outputs[2] = constant_op.constant(0.)
            loss = nn.l2_loss(outputs[0])
          if is_training:
            gradients = g.gradient(loss, [x1, scale, offset])
          else:
            gradients = [constant_op.constant(0.)] * 3
          return outputs + gradients

        # pylint: enable=cell-var-from-loop

        self._test_loop_fn(loop_fn, 3)

  def test_log_softmax(self):
    logits = random_ops.random_uniform([3, 2, 4])

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      return (nn.log_softmax(logits_i), nn.log_softmax(logits_i, axis=0),
              nn.log_softmax(logits_i, axis=-1))

    self._test_loop_fn(loop_fn, 3)

  def test_softmax(self):
    logits = random_ops.random_uniform([3, 2, 4])

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      return (nn.softmax(logits_i), nn.softmax(logits_i, axis=0),
              nn.softmax(logits_i, axis=-1))

    self._test_loop_fn(loop_fn, 3)

  def test_softmax_cross_entropy_with_logits(self):
    with backprop.GradientTape(persistent=True) as g:
      logits = random_ops.random_uniform([3, 2, 4])
      g.watch(logits)
      labels = random_ops.random_uniform([3, 2, 4])
      labels /= math_ops.reduce_sum(labels, axis=[2], keepdims=True)

    def loop_fn(i):
      with g:
        logits_i = array_ops.gather(logits, i)
        labels_i = array_ops.gather(labels, i)
        loss = nn.softmax_cross_entropy_with_logits(
            labels=labels_i, logits=logits_i)
        total_loss = math_ops.reduce_sum(loss)
      return loss, g.gradient(total_loss, logits_i)

    self._test_loop_fn(loop_fn, 3)

  def test_sparse_softmax_cross_entropy_with_logits(self):
    logits = random_ops.random_uniform([3, 2, 4])
    labels = random_ops.random_uniform(
        shape=[3, 2], maxval=4, dtype=dtypes.int32)

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      labels_i = array_ops.gather(labels, i)
      loss = nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels_i, logits=logits_i)
      return loss

    self._test_loop_fn(loop_fn, 3)


class RandomTest(PForTestCase):

  # The random values generated in the two implementations are not guaranteed
  # to match. So we only check the returned shapes.
  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
    outputs = self._run_targets(targets1, targets2)
    n = len(outputs) // 2
    for i in range(n):
      self.assertAllEqual(outputs[i].shape, outputs[i + n].shape)

  def test_random_uniform(self):

    def loop_fn(_):
      return random_ops.random_uniform([3])

    self._test_loop_fn(loop_fn, 5)

  def test_random_uniform_int(self):

    def loop_fn(_):
      return random_ops.random_uniform([3], maxval=1, dtype=dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  def test_random_standard_normal(self):

    def loop_fn(_):
      return random_ops.random_normal([3])

    self._test_loop_fn(loop_fn, 5)

  def test_truncated_normal(self):

    def loop_fn(_):
      return random_ops.truncated_normal([3])

    self._test_loop_fn(loop_fn, 5)

  def test_random_gamma_invariant_alpha(self):

    def loop_fn(_):
      return random_ops.random_gamma([3], alpha=[0.5])

    self._test_loop_fn(loop_fn, 5)

  def test_random_gamma_varying_alpha(self):
    alphas = math_ops.exp(random_ops.random_normal([5, 3, 2]))

    def loop_fn(i):
      alphas_i = array_ops.gather(alphas, i)
      # Test both scalar and non-scalar params and shapes.
      return (random_ops.random_gamma(alpha=alphas_i[0, 0], shape=[]),
              random_ops.random_gamma(alpha=alphas_i, shape=[]),
              random_ops.random_gamma(alpha=alphas_i[0, 0], shape=[3]),
              random_ops.random_gamma(alpha=alphas_i, shape=[3]))

    self._test_loop_fn(loop_fn, 5)

  def test_random_poisson_v2_invariant_rate(self):

    def loop_fn(_):
      return random_ops.random_poisson(lam=[1.3], shape=[3])

    self._test_loop_fn(loop_fn, 5)

  def test_random_poisson_v2_varying_rate(self):
    rates = math_ops.exp(random_ops.random_normal([5, 3, 2]))

    def loop_fn(i):
      rates_i = array_ops.gather(rates, i)
      # Test both scalar and non-scalar params and shapes.
      return (random_ops.random_poisson(lam=rates_i[0, 0], shape=[]),
              random_ops.random_poisson(lam=rates_i, shape=[]),
              random_ops.random_poisson(lam=rates_i[0, 0], shape=[3]),
              random_ops.random_poisson(lam=rates_i, shape=[3]))

    self._test_loop_fn(loop_fn, 5)

  def test_random_multinomial_invariant_logits(self):

    def loop_fn(_):
      return random_ops.categorical(logits=[[1., -1.]], num_samples=3)

    self._test_loop_fn(loop_fn, 5)

  def test_random_multinomial_varying_logits(self):
    logits = random_ops.random_normal([5, 3, 2])

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      return random_ops.categorical(logits_i, num_samples=3)

    self._test_loop_fn(loop_fn, 5)


class StatelessRandomTest(PForTestCase):

  # This test currently only tests that the vectorized and non-vectorized
  # outputs have same shapes. This is needed since under XLA compilation,
  # stateless random numbers can generate different random numbers.
  # TODO(agarwal): switch to checking for actual values matching once
  # b/149402339 is resolved.
  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
    outputs = self._run_targets(targets1, targets2)
    n = len(outputs) // 2
    for i in range(n):
      self.assertAllEqual(outputs[i].shape, outputs[i + n].shape)

  # TODO(agarwal): add tests for other random functions
  def test_multinomial(self):
    seeds = [[1, 2], [3, 4]]
    logits = random_ops.random_uniform([2, 3, 4])

    def loop_fn(i):
      logits_0 = array_ops.gather(logits, 0)
      logits_i = array_ops.gather(logits, i)
      seeds_0 = array_ops.gather(seeds, 0)
      seeds_i = array_ops.gather(seeds, i)
      return (stateless_random_ops.stateless_categorical(
          logits=logits_i, num_samples=3, seed=seeds_i),
              stateless_random_ops.stateless_categorical(
                  logits=logits_i, num_samples=3, seed=seeds_0),
              stateless_random_ops.stateless_categorical(
                  logits=logits_0, num_samples=3, seed=seeds_i),
              stateless_random_ops.stateless_categorical(
                  logits=logits_0, num_samples=3, seed=seeds_0))

    self._test_loop_fn(loop_fn, 2)


class LoggingTest(PForTestCase):

  def test_print(self):
    x = random_ops.random_uniform([3, 5])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      return logging_ops.Print(
          x1, [x1, "x1", array_ops.shape(x1)], summarize=10)

    self._test_loop_fn(loop_fn, 3)

  def test_print_v2(self):
    x = constant_op.constant([1, 2, 3])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      with ops.control_dependencies([
          logging_ops.print_v2(
              x1, "x1", array_ops.shape(x1), summarize=10)]):
        return array_ops.identity(x1)

    self._test_loop_fn(loop_fn, 3)

    with self.captureWritesToStream(sys.stderr) as printed:
      self.evaluate(pfor_control_flow_ops.pfor(loop_fn, 3))
    self.assertIn("[1 2 3] x1 []", printed.contents())

  def test_assert(self):

    def loop_fn(i):
      return control_flow_ops.Assert(i < 10, [i, [10], [i + 1]])

    # TODO(agarwal): make this work with for_loop.
    with session.Session() as sess:
      sess.run(pfor_control_flow_ops.pfor(loop_fn, 3))
      sess.run(pfor_control_flow_ops.pfor(
          lambda i, pfor_config: loop_fn(i), 3))


class TensorArrayTest(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.disable_control_flow_v2()
    super(TensorArrayTest, self).setUp()

  def tearDown(self):
    if self._enabled:
      control_flow_v2_toggles.enable_control_flow_v2()
    super(TensorArrayTest, self).tearDown()

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_read(self):

    ta = tensor_array_ops.TensorArray(
        dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1)

    def loop_fn(i):
      return ta.read(i), ta.read(0)

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_gather(self):

    ta = tensor_array_ops.TensorArray(
        dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1)

    def loop_fn(i):
      return ta.gather([i]), ta.gather([0, 1])

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_write_and_scatter(self):

    t = tensor_array_ops.TensorArray(dtypes.int32, 10, clear_after_read=False)
    handle = t.handle

    def loop_fn(i):
      ta = t.write(i + 2, 2 * i).write(i, 5)
      ta = ta.scatter([4 + i], [4]).scatter([6 + i, 8 + i], [6 + i, 8 + i])
      return ta.flow

    t1 = pfor_control_flow_ops.pfor(loop_fn, iters=2)
    out1 = tensor_array_ops.TensorArray(
        dtypes.int32, handle=handle, flow=t1[-1]).stack()
    output1 = self._run_targets(out1)

    t2 = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, iters=2)
    out2 = tensor_array_ops.TensorArray(
        dtypes.int32, handle=handle, flow=t2[-1]).stack()
    output2 = self._run_targets(out2)
    self.assertAllClose(output2, output1)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_write(self):

    def loop_fn(i):
      # TODO(agarwal): switching the order of writes to ta1 does not work.
      ta1 = tensor_array_ops.TensorArray(dtypes.int32,
                                         2).write(0, i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(dtypes.int32, 1).write(0, 1)
      return ta1.stack(), ta2.stack()

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_scatter(self):

    def loop_fn(i):
      # TODO(agarwal): switching the order of scatter to ta1 does not work.
      ta1 = tensor_array_ops.TensorArray(dtypes.int32, 2).scatter(
          [0], [[i, 2]]).scatter([1], [[1, 2]])
      ta2 = tensor_array_ops.TensorArray(dtypes.int32, 2).scatter(
          [0], [3]).scatter([1], [4])
      return ta1.stack(), ta2.stack()

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_read(self):

    def loop_fn(i):
      ta1 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2)
      # TODO(agarwal): ta1.read(i) currently is not supported.
      return ta1.read(0), ta2.read(0), ta2.read(i)

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_gather(self):

    def loop_fn(i):
      ta1 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2)
      # TODO(agarwal): ta1.read(i) currently is not supported.
      return ta1.gather([0, 1]), ta2.gather([0, 1]), ta2.gather([i])

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_grad(self):
    x = random_ops.random_uniform([3, 2])
    ta = tensor_array_ops.TensorArray(
        dtypes.float32, 3, clear_after_read=False).unstack(x)
    y = math_ops.square(ta.stack())

    def loop_fn(i):
      y_i = array_ops.gather(y, i)
      grad = gradient_ops.gradients(y_i, x)[0]
      return array_ops.gather(grad, i)

    t1 = pfor_control_flow_ops.pfor(loop_fn, iters=3)
    # y = x * x. Hence dy/dx = 2 * x.
    actual_grad = 2.0 * x
    with session.Session() as sess:
      actual_grad, computed_grad = sess.run([t1, actual_grad])
      self.assertAllClose(actual_grad, computed_grad)


@test_util.run_all_in_graph_and_eager_modes
class TensorListTest(PForTestCase):

  def test_create_outside_and_write(self):
    handle1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    handle2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)

    def loop_fn(i):
      h1 = list_ops.tensor_list_set_item(handle1, 0, i)
      h1 = list_ops.tensor_list_set_item(h1, 1, 1)
      h2 = list_ops.tensor_list_set_item(handle2, 0, 1)
      return (list_ops.tensor_list_stack(h1, dtypes.int32),
              list_ops.tensor_list_stack(h2, dtypes.int32))

    self._test_loop_fn(loop_fn, 3)

  def _make_graph_def(self, text):
    ret = graph_pb2.GraphDef()
    text_format.Parse(text, ret)
    return ret

  def test_no_fallback_with_internal_stacking(self):

    # Create an op (really a function) that pfor definitely does not have a
    # converter for. Assumes pfor does not start looking up function
    # definitions for op-type-is-function-name calls.
    @def_function.function
    def opaque_list_fetch(x):
      array_ops.identity(x)
      return list_ops.tensor_list_get_item(x, 0, dtypes.int32)

    external_handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    opaque_list_fetch_concrete = opaque_list_fetch.get_concrete_function(
        external_handle)
    opaque_list_fetch_name = opaque_list_fetch_concrete.name

    def loop_fn(i):
      h1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      h1 = list_ops.tensor_list_set_item(h1, 0, i)
      opaque_list_fetch_concrete.add_to_graph()
      graph_def = self._make_graph_def("""
          node { name: 'x' op: 'Placeholder'
            attr { key: 'dtype' value { type: DT_FLOAT } }}
          node { name: 'fn' op: '""" + opaque_list_fetch_name.decode()
                                       + """' input: 'x:0' }""")
      return importer.import_graph_def(
          graph_def,
          input_map={"x:0": h1},
          return_elements=["fn"],
          name="import")[0].outputs[0]

    with self.assertRaisesRegex(ValueError, "No pfor vectorization"):
      self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
    with self.assertRaisesRegex(ValueError, "No pfor vectorization"):
      self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)

  def test_create_inside_and_write(self):

    def loop_fn(i):
      h1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      h1 = list_ops.tensor_list_set_item(h1, 0, i)
      h1 = list_ops.tensor_list_set_item(h1, 1, 1)
      h2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      h2 = list_ops.tensor_list_set_item(h2, 0, 1)
      return (list_ops.tensor_list_stack(h1, dtypes.int32),
              list_ops.tensor_list_stack(h2, dtypes.int32))

    self._test_loop_fn(loop_fn, 3)

  def test_create_outside_and_read(self):
    handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    handle = list_ops.tensor_list_set_item(handle, 0, 0)
    handle = list_ops.tensor_list_set_item(handle, 1, 1)

    def loop_fn(i):
      return (list_ops.tensor_list_get_item(handle, i, dtypes.int32),
              list_ops.tensor_list_get_item(handle, 0, dtypes.int32),
              list_ops.tensor_list_length(handle),
              list_ops.tensor_list_element_shape(handle, dtypes.int32),
              list_ops.tensor_list_element_shape(handle, dtypes.int64))

    self._test_loop_fn(loop_fn, 2)

  def test_create_inside_and_read(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      handle = list_ops.tensor_list_set_item(handle, 0, i)
      handle = list_ops.tensor_list_set_item(handle, 1, 1)
      return (list_ops.tensor_list_get_item(handle, 0, dtypes.int32),
              list_ops.tensor_list_get_item(handle, i, dtypes.int32),
              list_ops.tensor_list_length(handle),
              list_ops.tensor_list_element_shape(handle, dtypes.int32),
              list_ops.tensor_list_element_shape(handle, dtypes.int64))

    self._test_loop_fn(loop_fn, 2)

  def test_create_outside_and_push_back(self):
    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)

    def loop_fn(i):
      handle = list_ops.tensor_list_push_back(h, [i, 2])
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_create_inside_and_push_back(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_push_back(handle, [i, 2])
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_pop_back_no_shape(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      handle = list_ops.tensor_list_push_back(handle, [i, 2])
      handle, tensor = list_ops.tensor_list_pop_back(handle, dtypes.int32)
      return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_pop_back_no_shape_capture(self):
    h = list_ops.tensor_list_reserve([2], 1, dtypes.int32)
    h = list_ops.tensor_list_push_back(h, [1, 2])

    def loop_fn(i):
      handle, tensor = list_ops.tensor_list_pop_back(h, dtypes.int32)
      handle = list_ops.tensor_list_push_back(handle, [1, i])
      return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_pop_back_with_shape(self):

    @def_function.function
    def loop_fn(i):
      with backprop.GradientTape() as tape:
        handle = list_ops.tensor_list_reserve(None, 1, dtypes.float32)
        x = math_ops.cast(i, dtypes.float32)[None]
        tape.watch(x)
        handle = list_ops.tensor_list_push_back(handle, x)
        stacked = list_ops.tensor_list_stack(handle, dtypes.float32)
      list_grad = tape.gradient(stacked, x, x)
      self.assertEqual("TensorListPopBack", list_grad.op.type)
      return list_grad, stacked, list_grad.op.inputs[1]

    self._test_loop_fn(loop_fn, 3)

  def test_create_outside_and_scatter(self):
    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)

    def loop_fn(i):
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=h)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_create_inside_and_scatter(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_loop_variant_scatter_indices(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 10, dtypes.int32)
      handle = list_ops.tensor_list_scatter(
          [[1, i], [i + 1, 2]],
          [i, i + 5], input_handle=handle)
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  def test_loop_variant_scatter_duplicate_indices(self):
    if test_util.is_gpu_available():
      self.skipTest(
          "Flaky in some GPU configurations due to TensorScatterNdUpdate "
          "nondeterminism.")

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 10, dtypes.int32)
      handle = list_ops.tensor_list_scatter(
          [[1, i], [1, i + 1], [i + 2, 3]],
          [i, i, i + 2], input_handle=handle)
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  def test_create_outside_and_gather(self):
    handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
    handle = list_ops.tensor_list_scatter([[2, 3]], [0], input_handle=handle)
    handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)

    def loop_fn(i):
      return (list_ops.tensor_list_gather(handle, [0, 1], dtypes.int32),
              list_ops.tensor_list_gather(handle, [i], dtypes.int32))

    self._test_loop_fn(loop_fn, 2)

  def test_create_inside_and_gather(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return (list_ops.tensor_list_gather(handle, [0, 1], dtypes.int32),
              list_ops.tensor_list_gather(handle, [i], dtypes.int32))

    self._test_loop_fn(loop_fn, 2)

  def test_create_inside_and_concat(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return gen_list_ops.tensor_list_concat_v2(
          handle,
          element_dtype=dtypes.int32,
          element_shape=[2],
          leading_dims=[])

    output = pfor_control_flow_ops.pfor(loop_fn, 2)
    self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], output[0])
    self.assertAllClose([[2, 2], [2, 2]], output[1])

  def test_create_outside_and_concat(self):
    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)

    def loop_fn(i):
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=h)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return gen_list_ops.tensor_list_concat_v2(
          handle,
          element_dtype=dtypes.int32,
          element_shape=[2],
          leading_dims=[])

    output = pfor_control_flow_ops.pfor(loop_fn, 2)
    self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], output[0])
    self.assertAllClose([[2, 2], [2, 2]], output[1])

  def test_tensor_list_from_tensor(self):
    t = random_ops.random_uniform([2, 3, 4])

    def loop_fn(i):
      handle = list_ops.tensor_list_from_tensor(array_ops.gather(t, i), [4])
      return list_ops.tensor_list_stack(handle, t.dtype)

    self._test_loop_fn(loop_fn, 2)

  @test_util.enable_control_flow_v2
  def test_tensor_list_reserve_while_loop(self):
    # Here a loop invariant TensorList is captured by a while_loop, which then
    # performs loop dependent operations on it, resulting in a loop variant
    # output. This forces stacking of the variant handle captured by the
    # while_loop.
    # We handle this particular case by forcing vectorization of
    # TensorListReserve operation.

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < 2, lambda j, h:
          (j + 1, list_ops.tensor_list_set_item(h, j, i)), (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 2)

  @test_util.enable_control_flow_v2
  def test_tensor_list_while_loop_stacked_cond_stacked_list(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_from_tensor([20, 21, 22, 23, i], [])
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < i,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
          (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  @test_util.enable_control_flow_v2
  def test_tensor_list_while_loop_stacked_cond_stacked_list_unknown_shape(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve(None, 5, dtypes.int32)
      _, handle = control_flow_ops.while_loop(
          lambda j, _: j < 5,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, 0)),
          (0, handle))
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < i,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
          (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  @test_util.enable_control_flow_v2
  def test_tensor_list_while_loop_stacked_cond_unstacked_list(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_from_tensor([20, 21, 22, 23, 24], [])
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < i,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
          (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  def test_tensor_list_addn_already_stacked(self):

    def loop_fn(i):
      l1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      l1 = list_ops.tensor_list_set_item(l1, 0, i)
      l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      l2 = list_ops.tensor_list_set_item(l2, 1, i)
      return list_ops.tensor_list_stack(math_ops.add_n([l1, l2]), dtypes.int32)

    self._test_loop_fn(loop_fn, 2)

  def test_tensor_list_addn_stacking_required(self):
    l1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    l1 = list_ops.tensor_list_set_item(l1, 1, 1)

    def loop_fn(i):
      l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      l2 = list_ops.tensor_list_set_item(l2, 1, i)
      return list_ops.tensor_list_stack(
          math_ops.add_n([l1, l2]), dtypes.int32)

    self._test_loop_fn(loop_fn, 2)


@test_util.run_all_in_graph_and_eager_modes
class TensorTest(PForTestCase):

  def test_loop_variant_scatter_update_no_shape(self):
    if test_util.is_gpu_available():
      self.skipTest(
          "Flaky in some GPU configurations due to TensorScatterNdUpdate "
          "nondeterminism.")

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32),
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32),
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32)
    ])
    def shapeless_func(tensor, indices, updates):
      return array_ops.tensor_scatter_nd_update(tensor, indices, updates)

    def loop_fn(i):
      tensor = [0, 0, 0, 0, 0, 0, 0, 0]
      indices = [[i], [i + 1], [i + 3], [i + 2]]
      updates = [i, i - 10, i + 11, 12]
      return shapeless_func(tensor, indices, updates)

    self._test_loop_fn(loop_fn, 5)

  def test_loop_variant_scatter_update_singles(self):
    if test_util.is_gpu_available():
      self.skipTest(
          "Flaky in some GPU configurations due to TensorScatterNdUpdate "
          "nondeterminism.")

    def loop_fn(i):
      tensor = [0, 0, 0, 0, 0, 0, 0, 0]
      indices = [[i], [i+1], [i+3], [i+2]]
      updates = [i, i-10, i+11, 12]
      return array_ops.tensor_scatter_nd_update(tensor, indices, updates)

    self._test_loop_fn(loop_fn, 5)

  def test_loop_variant_scatter_update_slices(self):
    if test_util.is_gpu_available():
      self.skipTest(
          "Flaky in some GPU configurations due to TensorScatterNdUpdate "
          "nondeterminism.")

    def loop_fn(i):
      tensor = array_ops.zeros([10, 3], dtype=dtypes.int32)
      indices = [[i+2], [4]]
      updates = [[1, i*2, 3], [i+4, i-5, 6]]
      return array_ops.tensor_scatter_nd_update(tensor, indices, updates)

    self._test_loop_fn(loop_fn, 5)

  def test_loop_variant_scatter_update_multi_dim_index(self):
    if test_util.is_gpu_available():
      self.skipTest(
          "Flaky in some GPU configurations due to TensorScatterNdUpdate "
          "nondeterminism.")

    def loop_fn(i):
      tensor = array_ops.zeros([10, 3], dtype=dtypes.int32)
      indices = [[i+2, 1], [4, 2]]
      updates = [i, 5]
      return array_ops.tensor_scatter_nd_update(tensor, indices, updates)

    self._test_loop_fn(loop_fn, 5)

  def test_loop_variant_scatter_update_folded_indices(self):
    if test_util.is_gpu_available():
      self.skipTest(
          "Flaky in some GPU configurations due to TensorScatterNdUpdate "
          "nondeterminism.")

    def loop_fn(i):
      tensor = array_ops.zeros([5, 5])
      indices = [
          [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]],
          [[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]],
      ]
      updates = [
          [1, i, 1, 1, 1],
          [1, 1, i+2, 1, i-5],
      ]
      return array_ops.tensor_scatter_nd_update(tensor, indices, updates)

    self._test_loop_fn(loop_fn, 5)


class OptionalTest(PForTestCase):

  def test_optional_from_value(self):

    def loop_fn(i):
      o = gen_dataset_ops.optional_from_value(
          [i, i + 1, constant_op.constant(3)])
      gen_dataset_ops.optional_none()
      return gen_dataset_ops.optional_get_value(
          o, [dtypes.int32, dtypes.int32, dtypes.int32],
          [[], [], []])

    self._test_loop_fn(loop_fn, 2)


class StackTest(PForTestCase):

  @test_util.run_v1_only("b/122612051")
  def test_stack_inside_loop_invariant(self):

    def loop_fn(_):
      s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
      op1 = data_flow_ops.stack_push_v2(s, 1)
      with ops.control_dependencies([op1]):
        op2 = data_flow_ops.stack_push_v2(s, 2)
      with ops.control_dependencies([op2]):
        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      with ops.control_dependencies([e2]):
        e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      return e1, e2

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_stack_inside_push_loop_dependent(self):

    def loop_fn(i):
      s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
      op1 = data_flow_ops.stack_push_v2(s, i)
      with ops.control_dependencies([op1]):
        op2 = data_flow_ops.stack_push_v2(s, 2)
      with ops.control_dependencies([op2]):
        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      with ops.control_dependencies([e2]):
        e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      return e1, e2

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_stack_outside_pop(self):
    s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
    op = data_flow_ops.stack_push_v2(s, 5)
    with ops.control_dependencies([op]):
      op = data_flow_ops.stack_push_v2(s, 6)
    with ops.control_dependencies([op]):
      op = data_flow_ops.stack_push_v2(s, 7)

    def loop_fn(_):
      e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      with ops.control_dependencies([e1]):
        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      return e1, e2

    with ops.control_dependencies([op]):
      e1, e2 = pfor_control_flow_ops.pfor(loop_fn, iters=2)
    with ops.control_dependencies([e1, e2]):
      e3 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
    v1, v2, v3 = self._run_targets([e1, e2, e3], run_init=False)
    self.assertAllEqual([7, 7], v1)
    self.assertAllEqual([6, 6], v2)
    self.assertAllEqual(5, v3)

  @test_util.run_v1_only("b/122612051")
  def test_stack_outside_push(self):
    s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)

    def loop_fn(_):
      return data_flow_ops.stack_push_v2(s, 7)

    with self.assertRaisesRegex(ValueError, "StackPushV2 not allowed.*"):
      pfor_control_flow_ops.pfor(loop_fn, iters=2)


# TODO(agarwal): test nested while_loops. This currently requires converting a
# tf.cond.
class WhileV1Test(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.disable_control_flow_v2()
    super(WhileV1Test, self).setUp()

  def tearDown(self):
    if self._enabled:
      control_flow_v2_toggles.enable_control_flow_v2()
    super(WhileV1Test, self).tearDown()

  def test_while_outside_loop(self):

    x = control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    def loop_fn(i):
      return x + i

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_invariant_while(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_invariant_while_with_control_dependency(self):

    def loop_fn(i):
      with ops.control_dependencies([i]):
        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
                                           [0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while_with_stateful_ops(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4, lambda j, x:
          (j + 1, x + random_ops.random_uniform([])), [0, 0.])[0]

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
                                         (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while(self):
    x = random_ops.random_uniform([3, 5])


# TODO(agarwal): test nested while_loops. This currently requires converting a
# tf.cond.
class WhileV1Test(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.disable_control_flow_v2()
    super(WhileV1Test, self).setUp()

  def tearDown(self):
    if self._enabled:
      control_flow_v2_toggles.enable_control_flow_v2()
    super(WhileV1Test, self).tearDown()

  def test_while_outside_loop(self):

    x = control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    def loop_fn(i):
      return x + i

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_invariant_while(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_invariant_while_with_control_dependency(self):

    def loop_fn(i):
      with ops.control_dependencies([i]):
        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
                                            [0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while_with_stateful_ops(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + random_ops.random_uniform([])), [0, 0.])[0]

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4, lambda j, x: (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while(self):
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      _, total = control_flow_ops.while_loop(
          lambda j, _: j < lengths_i,
          lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
      return total

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while_jacobian(self):
    x = random_ops.random_uniform([1, 3])
    y = random_ops.random_uniform([3, 3])

    # out = x @ y @ y @ y @ y, where @ is matmul operator.
    _, out = control_flow_ops.while_loop(
        lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
        [0, x])

    def loop_fn(i):
      out_i = array_ops.gather(out, i, axis=1)
      return array_ops.reshape(gradient_ops.gradients(out_i, x)[0], [-1])

    out = pfor_control_flow_ops.pfor(loop_fn, iters=3)

    # The above code does not work with tf.while_loop instead of pfor. So we
    # manually compute the expected output here.
    # Note that the gradient of the output w.r.t. x is (y @ y @ y @ y)^T.
    expected_output = y
    for _ in range(3):
      expected_output = math_ops.matmul(expected_output, y)
    expected_output = array_ops.transpose(expected_output, [1, 0])

    with session.Session() as sess:
      out, expected = sess.run([out, expected_output])
      self.assertAllClose(expected, out)

  @test_util.run_v1_only("b/122612051")
  def test_tensor_array_as_loop_variable(self):

    def loop_fn(i):

      def body(j, ta):
        ta = ta.write(j, i + j * j)
        return j + 1, ta

      _, ta = control_flow_ops.while_loop(
          lambda j, _: j < 4, body,
          (0, tensor_array_ops.TensorArray(dtypes.int32, size=4)))
      return ta.stack()

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_read_tensor_array_partitioned_indices(self):
    # Note that tensor array values are pfor loop dependent, and the while loop
    # termination condition is also dependent on pfor iteration.
    def loop_fn(i):
      ta = tensor_array_ops.TensorArray(dtypes.int32, size=6)
      ta = ta.unstack(i + list(range(5)))

      def body(j, s):
        return j + 1, s + ta.read(j)

      _, s = control_flow_ops.while_loop(lambda j, _: j < i, body, (0, 0))
      return s

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_external_while_loop_grad(self):
    # Here we test that external while_loops that are extended from inside pfor
    # (due to gradient calls) are not actually converted. If the below were
    # converted, all pfor iterations would write to the same tensor array
    # indices.
    x = constant_op.constant(1.)

    def body(j, ta):
      ta = ta.write(j, x)
      return j + 1, ta

    _, ta = control_flow_ops.while_loop(
        lambda j, _: j < 4, body,
        (0, tensor_array_ops.TensorArray(dtypes.float32, size=4)))
    out = ta.stack()

    def loop_fn(i):
      out_i = array_ops.gather(out, i)
      return gradient_ops.gradients(out_i, x)[0]

    with session.Session() as sess:
      # out is [x, x, x]. Hence the gradients should be [1, 1, 1].
      self.assertAllEqual([1, 1, 1],
                          sess.run(pfor_control_flow_ops.pfor(loop_fn, 3)))

  @test_util.run_v1_only("b/122612051")
  def test_tensor_array_grad(self):
    inp = constant_op.constant(np.random.rand(3, 4, 2), dtype=dtypes.float32)
    ta = tensor_array_ops.TensorArray(dtypes.float32, size=3)
    ta = ta.unstack(inp)

    def loop_fn(i):

      def body(j, x):
        value = ta.gather([j])
        value = array_ops.gather(array_ops.reshape(value, [4, 2]), i)
        return j + 1, x + value

      _, out = control_flow_ops.while_loop(lambda j, _: j < 3, body,
                                           (0, array_ops.zeros([2])))
      out = math_ops.reduce_prod(out)
      return out, gradient_ops.gradients(out, inp)[0]

    pfor_out, pfor_out_grad = pfor_control_flow_ops.pfor(loop_fn, 4)
    # Note that tf.while_loop does not work in the setup above. So we manually
    # construct the equivalent computation of the above loops here.
    real_out = math_ops.reduce_sum(inp, axis=[0])
    real_out = math_ops.reduce_prod(real_out, axis=[1])
    # Note that gradients of real_out will accumulate the gradients across the
    # output value. Hence we do the same aggregation on pfor_out_grad.
    real_out_grad = gradient_ops.gradients(real_out, inp)[0]
    sum_pfor_out_grad = math_ops.reduce_sum(pfor_out_grad, axis=[0])

    with session.Session() as sess:
      v1, v2, v1_grad, v2_grad = sess.run(
          [pfor_out, real_out, sum_pfor_out_grad, real_out_grad])
      self.assertAllClose(v1, v2)
      self.assertAllClose(v1_grad, v2_grad)


def dynamic_lstm_input_fn(batch_size, state_size, max_steps):
  # We make inputs and sequence_length constant so that multiple session.run
  # calls produce the same result.
  inputs = constant_op.constant(
      np.random.rand(batch_size, max_steps, state_size), dtype=dtypes.float32)
  sequence_length = np.random.randint(0, size=[batch_size], high=max_steps + 1)
  sequence_length = constant_op.constant(sequence_length, dtype=dtypes.int32)
  return inputs, sequence_length


def create_dynamic_lstm(cell_fn, batch_size, state_size, max_steps):
  cell = cell_fn(state_size)
  inputs, sequence_length = dynamic_lstm_input_fn(batch_size, state_size,
                                                  max_steps)
  inputs_ta = tensor_array_ops.TensorArray(
      dtypes.float32, size=max_steps, element_shape=[batch_size, state_size])
  inputs_time_major = array_ops.transpose(inputs, [1, 0, 2])
  inputs_ta = inputs_ta.unstack(inputs_time_major)
  zeros = array_ops.zeros([state_size])

  def loop_fn(i):
    sequence_length_i = array_ops.gather(sequence_length, i)

    def body_fn(t, state, ta):
      inputs_t = array_ops.expand_dims(
          array_ops.gather(inputs_ta.read(t), i), 0)
      output, new_state = cell(inputs_t, state)
      output = array_ops.reshape(output, [-1])
      # TODO(agarwal): one optimization that dynamic_rnn uses is to avoid the
      # array_ops.where when t < min(sequence_length). Doing that requires
      # supporting tf.cond pfor conversion.
      done = t >= sequence_length_i
      output = array_ops.where(done, zeros, output)
      ta = ta.write(t, output)
      new_state = [
          array_ops.where(done, s, ns)
          for s, ns in zip(nest.flatten(state), nest.flatten(new_state))
      ]
      new_state = nest.pack_sequence_as(state, new_state)
      return t + 1, new_state, ta

    def condition_fn(t, _, unused):
      del unused
      return t < max_steps

    initial_state = cell.zero_state(1, dtypes.float32)
    _, state, ta = control_flow_ops.while_loop(condition_fn, body_fn, [
        0, initial_state,
        tensor_array_ops.TensorArray(dtypes.float32, max_steps)
    ])

    new_state = [array_ops.reshape(x, [-1]) for x in nest.flatten(state)]
    new_state = nest.pack_sequence_as(initial_state, new_state)
    return ta.stack(), new_state

  pfor_output = pfor_control_flow_ops.pfor(loop_fn, batch_size)
  tf_output = rnn.dynamic_rnn(
      cell,
      inputs,
      sequence_length=sequence_length,
      initial_state=cell.zero_state(batch_size, dtypes.float32))
  return pfor_output, tf_output
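

# Illustrative sketch only: the helpers above are consumed by building both the
# pfor-vectorized RNN/LSTM and the rnn.dynamic_rnn reference, then asserting
# that the two agree (see RNNTest and Benchmarks below). The sizes here are
# arbitrary, and this helper is not called by the tests.
def _example_compare_dynamic_lstm(test_case):
  pfor_outputs, tf_outputs = create_dynamic_lstm(
      rnn_cell.BasicRNNCell, batch_size=2, state_size=3, max_steps=4)
  test_case.run_and_assert_equal(pfor_outputs, tf_outputs)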


@test_util.run_all_in_graph_and_eager_modes
class WhileV2Test(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()
    super(WhileV2Test, self).setUp()

  def tearDown(self):
    if not self._enabled:
      control_flow_v2_toggles.disable_control_flow_v2()
    super(WhileV2Test, self).tearDown()

  def test_while_outside_loop(self):

    def _f():
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    def loop_fn(i):
      return _f() + i

    self._test_loop_fn(loop_fn, 3)

  def test_invariant_while(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    self._test_loop_fn(loop_fn, 3)

  def test_invariant_while_with_control_dependency(self):

    def loop_fn(i):
      with ops.control_dependencies([i]):
        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
                                            [0])

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_stateful_ops(self):

    def loop_fn(_):
      j, _ = control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + random_ops.random_uniform([])), [0, 0.])
      return j

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_variable(self):
    v = resource_variable_ops.ResourceVariable(5.)

    def loop_fn(_):
      _, output = control_flow_ops.while_loop(
          lambda j, x: j < 4, lambda j, x: (j + 1, x + v), [0, 0.])
      return output

    self._test_loop_fn(loop_fn, 3)

  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4, lambda j, x: (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3)

  def test_while(self):
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      return control_flow_ops.while_loop(
          lambda j, _: j < lengths_i,
          lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])

    self._test_loop_fn(loop_fn, 3)

  def test_while_change_input_invariance(self):
    # This tests cases where a loop-invariant input to while has loop-dependent
    # operations applied to it inside the while body.
    # It also tests inputs that are passed through.
    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, *_: j < i,
          lambda j, x, y, z, w: (j + 1, x + i, y + x, z, w), [
              0,
              constant_op.constant(0),
              constant_op.constant(1), i,
              constant_op.constant(2)
          ])

    self._test_loop_fn(loop_fn, 3)

  def test_while_shape_invariants(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, *_: j < 4,
          lambda j, x, y: (j + 1, x + i, y + 1),
          [0, constant_op.constant([0, 1]),
           constant_op.constant([2, 3])],
          shape_invariants=[
              None,
              tensor_shape.TensorShape([2]),
              tensor_shape.TensorShape([2])
          ])

    self._test_loop_fn(loop_fn, 3)

  def test_while_jacobian(self):
    # Note that we wrap the code below in a tf.function since we don't want the
    # while_loop call to be evaluated eagerly using a python loop.
    @def_function.function
    def _f(x, y, use_pfor):
      # out = x @ y @ y @ y @ y, where @ is matmul operator.
      _, out = control_flow_ops.while_loop(
          lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
          [0, x])

      def loop_fn(i):
        out_i = array_ops.gather(out, i, axis=1)
        grad = gradient_ops.gradients(out_i, x)
        return array_ops.reshape(grad[0], [-1])

      if use_pfor:
        return pfor_control_flow_ops.pfor(loop_fn, iters=3)
      else:
        return pfor_control_flow_ops.for_loop(
            loop_fn, iters=3, loop_fn_dtypes=out.dtype)

    x = constant_op.constant(np.random.uniform(size=(1, 3)))
    y = constant_op.constant(np.random.uniform(size=(3, 3)))
    self.assertAllClose(_f(x, y, True), _f(x, y, False))

  def test_scan(self):
    np.random.seed(seed=42)
    data = np.random.randn(3).astype(np.float32)

    def log_prob(x):
      return math_ops.reduce_sum(functional_ops.scan_v2(
          lambda _, yi: (x - yi)**2,
          elems=data,
          initializer=constant_op.constant(0.)))

    x = variables.Variable(array_ops.ones([2]))
    self.evaluate(x.initializer)
    v_log_prob = lambda x: pfor_control_flow_ops.vectorized_map(log_prob, x)
    theoretical, numerical = gradient_checker_v2.compute_gradient(
        v_log_prob, (x,), delta=1e-3)
    self.assertAllClose(theoretical, numerical, rtol=1e-2)

  def test_scan_captured_variable(self):
    if not context.executing_eagerly():
      self.skipTest("Test only written for 2.x")
    v = variables.Variable(math_ops.range(10, dtype=dtypes.float32))

    def loop_fn(idx):
      del idx
      return functional_ops.scan_v2(lambda _, i: array_ops.gather(v, i),
                                    elems=math_ops.range(v.shape[0]),
                                    initializer=0.0)

    with backprop.GradientTape() as tape:
      result = pfor_control_flow_ops.pfor(loop_fn, 2)
    self.assertAllClose([2.] * 10, tape.gradient(result, v))
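

# Illustrative sketch only: the pattern used in test_while_jacobian above,
# generalized. A Jacobian of a length-n output w.r.t. x can be built by
# vectorizing a per-column gradient with pfor. This helper is ours and is not
# called by the tests.
def _example_jacobian_via_pfor(out, x, num_columns):

  def loop_fn(i):
    out_i = array_ops.gather(out, i, axis=1)
    return gradient_ops.gradients(out_i, x)[0]

  return pfor_control_flow_ops.pfor(loop_fn, num_columns)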


@test_util.run_all_in_graph_and_eager_modes
class NestedControlFlowTest(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()
    super(NestedControlFlowTest, self).setUp()

  def tearDown(self):
    if not self._enabled:
      control_flow_v2_toggles.disable_control_flow_v2()
    super(NestedControlFlowTest, self).tearDown()

  def _cond(self, f=None, split=0):
    if f is None:
      f = lambda x, y: (x, y)

    def _f(x, y):
      return control_flow_ops.cond(y > split, lambda: f(x, y),
                                   lambda: (x + 1., y))

    return _f

  def _while(self, f=None):
    if f is None:
      f = lambda x, y: (x, y)

    def _f(x, y):
      return control_flow_ops.while_loop(
          lambda j, _: j < y,
          lambda j, t: (j + 1, t + array_ops.gather(f(x, y)[0], j)),
          [0, x])[1], y

    return _f

  def _test_helper(self, f):
    x = random_ops.random_uniform([5, 5])
    y = constant_op.constant([4, -1, 2, -2, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      y_i = array_ops.gather(y, i)
      return f(x_i, y_i)

    self._test_loop_fn(loop_fn, 5)

  def test_cond_while(self):
    self._test_helper(self._cond(self._while()))

  def test_while_cond(self):
    self._test_helper(self._while(self._cond()))

  def test_while_while(self):
    self._test_helper(self._while(self._while()))

  def test_cond_cond(self):
    self._test_helper(self._cond(self._cond()))


@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class StatelessIfTest(PForTestCase):

  def test_loop_variant_cond(self):
    x = [1, 2, 3, 4, 5.]
    y = 2.5

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      # Note that the output has a combination of then and else branches being
      # loop variant / invariant.
      return cond_v2.cond_v2(x_i < y, lambda: (y - x_i, y, 1., 2.),
                             lambda: (x_i - y, 0., y, 3.))

    self._test_loop_fn(loop_fn, iters=5)

  def test_loop_invariant_cond(self):
    x = [1, 2, 3, 4, 5.]
    y = 0.5
    z = random_ops.random_uniform([])

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      # Note that the output has a combination of then and else branches being
      # loop variant / invariant.
      return cond_v2.cond_v2(z < y, lambda: (y - x_i, y, 1., 2.),
                             lambda: (x_i - y, 0., y, 3.))

    self._test_loop_fn(loop_fn, iters=5)

  def test_empty_branch(self):
    x = [1, 2, 3, 4, 5.]
    y = 6.

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return cond_v2.cond_v2(
          x_i < y,  # Note that else branch is empty.
          lambda: (y - x_i, y, 1., 2.),
          lambda: (x_i - y, 0., y, 3.))

    self._test_loop_fn(loop_fn, iters=5)


@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class IfTest(PForTestCase):

  def test_read_var(self):
    self.skipTest("b/156438918")  # Flaky

    x = [1, 2, 3, 4, 5.]
    y = 2.5
    z = resource_variable_ops.ResourceVariable(5.)

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return cond_v2.cond_v2(x_i < y, lambda: z - x_i, lambda: z + x_i)

    self._test_loop_fn(loop_fn, iters=5)


class RNNTest(PForTestCase):

  @test_util.run_v1_only("b/122612051")
  def test_dynamic_rnn(self):
    pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, 3, 5,
                                                   7)
    self.run_and_assert_equal(pfor_outputs, tf_outputs)

  @test_util.run_v1_only("b/122612051")
  def test_dynamic_lstm(self):
    pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicLSTMCell, 3, 5,
                                                   7)
    self.run_and_assert_equal(pfor_outputs, tf_outputs)
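

# The benchmarks below use the standard tf.test.Benchmark machinery; they are
# typically run by passing a --benchmarks regex to the test binary (e.g.
# --benchmarks=Benchmarks.benchmark_matmul). The exact invocation depends on
# the build setup.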


# TODO(agarwal): benchmark numbers on GPU for graphs based on while_loop
# conversion don't look good. Some of it seems like a lot of copies between
# host and device. Optimize that.
class Benchmarks(test.Benchmark):

  def _run(self, targets, iters, name=None):

    def _done(t):
      # Note that we don't use tf.control_dependencies since that will not make
      # sure that the computation on GPU has actually finished. So we fetch the
      # first element of the output, and assume that this will not be called on
      # empty tensors.
      return array_ops.gather(array_ops.reshape(t, [-1]), 0)

    targets = [_done(x) for x in nest.flatten(targets)]
    sess = session.Session()
    with sess:
      init = variables.global_variables_initializer()
      sess.run(init)
      run_fn = sess.make_callable(targets)
      run_fn()  # Warm up
      begin = time.time()
      for _ in range(iters):
        run_fn()
      end = time.time()
      avg_time_ms = 1000 * (end - begin) / iters
      self.report_benchmark(iters=iters, wall_time=avg_time_ms, name=name)
      return avg_time_ms

  def benchmark_sess_run_overhead(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0)
      self._run(x, 10000, name="session_run_overhead")

  def benchmark_add(self):
    with ops.Graph().as_default():
      n = 256
      params = 1000
      x = random_ops.random_normal([n, params])
      y = random_ops.random_normal([n, params])

      def loop_fn(i):
        x_i = array_ops.gather(x, i)
        y_i = array_ops.gather(y, i)
        return x_i + y_i

      pfor_outputs = pfor_control_flow_ops.pfor(loop_fn, n)
      while_outputs = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, n)
      manual = x + y

      self._run(manual, 1000, name="manual_add")
      self._run(pfor_outputs, 1000, name="pfor_add")
      self._run(while_outputs, 100, name="while_add")

  def benchmark_matmul(self):
    with ops.Graph().as_default():
      n = 1024
      params = 1000
      x = random_ops.random_normal([n, params])
      y = random_ops.random_normal([params, params])

      def loop_fn(i):
        x_i = array_ops.expand_dims(array_ops.gather(x, i), 0)
        return math_ops.matmul(x_i, y)

      pfor_outputs = pfor_control_flow_ops.pfor(loop_fn, n)
      while_outputs = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, n)
      manual = math_ops.matmul(x, y)

      self._run(manual, 1000, name="manual_matmul")
      self._run(pfor_outputs, 1000, name="pfor_matmul")
      self._run(while_outputs, 100, name="while_matmul")

  def benchmark_map_fn(self):
    with ops.Graph().as_default():
      b = 256
      params = 1000
      inp = random_ops.random_normal((b, params))
      fn = lambda x: x * x

      def pfor_map_fn(f, x):
        return pfor_control_flow_ops.pfor(lambda i: f(array_ops.gather(x, i)),
                                          array_ops.shape(x)[0])

      map_output = map_fn.map_fn(fn, inp)
      pfor_output = pfor_map_fn(fn, inp)

      self._run(map_output, 100, name="tf_map_fn")
      self._run(pfor_output, 100, name="pfor_map_fn")

  def benchmark_basic_while(self):
    with ops.Graph().as_default():

      def loop_fn(i):
        _, s = control_flow_ops.while_loop(lambda t, x: t < i,
                                           lambda t, x: (t + 1, x + i), [0, 0])
        return s

      iters = 50
      pfor_output = pfor_control_flow_ops.pfor(loop_fn, iters)
      for_loop_output = pfor_control_flow_ops.for_loop(loop_fn, dtypes.int32,
                                                       iters)
      self._run(pfor_output, 100, name="pfor_basic")
      self._run(for_loop_output, 100, name="for_loop_basic")

  def benchmark_dynamic_rnn(self):
    with ops.Graph().as_default():
      pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, 128,
                                                     512, 16)
      self._run(pfor_outputs, 100, name="pfor_rnn")
      self._run(tf_outputs, 100, name="tf_rnn")

  def benchmark_reduction(self):
    n = 1024
    with ops.Graph().as_default():
      x = random_ops.random_uniform([n, n])
      w = random_ops.random_uniform([n, n])

      def loop_fn(i, pfor_config):
        x_i = array_ops.gather(x, i)
        return math_ops.reduce_sum(
            math_ops.matmul(pfor_config.reduce_concat(x_i), w))

      # Note that output_reduction will be tiled, so there may be some minor
      # overheads compared to output_no_reduction.
      output_reduction = pfor_control_flow_ops.pfor(loop_fn, n)
      output_no_reduction = math_ops.reduce_sum(math_ops.matmul(x, w))
      # Benchmark to test that reduction does not add overhead and its output
      # is treated as loop invariant.
      self._run(output_reduction, 30, name="matmul_reduction")
      self._run(output_no_reduction, 30, name="matmul_no_reduction")
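

# Note: when a loop body takes a second argument (as benchmark_reduction's
# loop_fn(i, pfor_config) above does), pfor passes a PForConfig object whose
# helpers such as reduce_concat let an iteration see values gathered across
# all iterations.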


class SparseTest(PForTestCase):

  @test_util.run_v1_only("b/122612051")
  def test_var_loop_len(self):
    num_iters = array_ops.placeholder(dtypes.int32)

    def loop_fn(_):
      return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6],
                                        [3])  # [0, 2, 0]

    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    with self.cached_session() as sess:
      sess.run(pfor, feed_dict={num_iters: 3})

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_none_stacked(self):
    num_iters = 10

    def loop_fn(_):
      return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6],
                                        [3])  # [0, 2, 0]

    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)

    indices = [[i, j] for i in range(num_iters) for j in range(3)]
    values = [4, 5, 6] * num_iters
    dense_shapes = [num_iters, 3]
    # Expected result: [[4, 5, 6], [4, 5, 6], [4, 5, 6], ...]
    manual = sparse_tensor.SparseTensor(indices, values, dense_shapes)
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_all_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      indices = array_ops.expand_dims(i, 0)
      return sparse_tensor.SparseTensor(indices, i, i + 1)  # [0, ..., 0, i]

    # Expected result: [[0], [0, 1], [0, 0, 2], [0, 0, 0, 3], ...]
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)],
                                        list(range(num_iters)),
                                        (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_indices_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      indices = array_ops.expand_dims(i, 0)
      return sparse_tensor.SparseTensor(indices, [1], [num_iters])

    # Expected result: identity matrix size num_iters * num_iters
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)],
                                        [1] * num_iters, (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_values_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      return sparse_tensor.SparseTensor([[0]], i, [num_iters])  # [i, 0, ..., 0]

    # Expected result: [[1, 0, ...], [2, 0, ...], [3, 0, ...], ...]
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)],
                                        list(range(num_iters)),
                                        (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_shapes_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      return sparse_tensor.SparseTensor([[0]], [1], i + 1)  # [1, 0, ..., 0]

    # Expected result: [[1, 0, 0, ...], [1, 0, 0, ...], ...]
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)],
                                        [1] * num_iters, (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_shapes_stacked_2D(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i + 1, dtypes.int64), 0)
      shape = array_ops.concat([i, i], 0)
      return sparse_tensor.SparseTensor([[0, 0]], [1], shape)  # [1, 0, ..., 0]

    # Expected result: [[[1, 0, ...], [0, ..., 0], [0, ..., 0], ...], ...]
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, 0, 0] for i in range(num_iters)],
                                        [1] * num_iters,
                                        (num_iters, num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)


# Dummy CompositeTensor to test CompositeTensor support.
class Particle(composite_tensor.CompositeTensor):
  """A (batch of) particles each defined by a mass and a scalar velocity."""

  def __init__(self, mass, velocity):
    mass = ops.convert_to_tensor(mass)
    velocity = ops.convert_to_tensor(velocity)
    self.shape = array_ops.broadcast_static_shape(mass.shape, velocity.shape)
    self.mass = mass
    self.velocity = velocity

  @property
  def _type_spec(self):
    return ParticleSpec(
        type_spec.type_spec_from_value(self.mass),
        type_spec.type_spec_from_value(self.velocity))


class ParticleSpec(type_spec.BatchableTypeSpec):

  def __init__(self, mass, velocity):
    self.shape = array_ops.broadcast_static_shape(
        mass.shape, velocity.shape)
    self.mass = mass
    self.velocity = velocity

  def _serialize(self):
    return (self.mass, self.velocity)

  @property
  def value_type(self):
    return Particle

  @property
  def _component_specs(self):
    return (self.mass, self.velocity)

  def _to_components(self, value):
    return (value.mass, value.velocity)

  def _from_components(self, components):
    return Particle(*components)

  def _pad_shape_to_full_rank(self, s):
    """Pad component shapes with 1's so all components have the same rank."""
    return tensor_shape.TensorShape(
        [1] * (self.shape.ndims - s.ndims)).concatenate(s)

  def _batch(self, batch_size):
    return ParticleSpec(
        mass=tensor_spec.TensorSpec(
            dtype=self.mass.dtype,
            shape=tensor_shape.TensorShape([batch_size]).concatenate(
                self._pad_shape_to_full_rank(self.mass.shape))),
        velocity=tensor_spec.TensorSpec(
            dtype=self.velocity.dtype,
            shape=tensor_shape.TensorShape([batch_size]).concatenate(
                self._pad_shape_to_full_rank(self.velocity.shape))))

  def _unbatch(self):
    return ParticleSpec(
        tensor_spec.TensorSpec(dtype=self.mass.dtype,
                               shape=self.mass.shape[1:]),
        tensor_spec.TensorSpec(dtype=self.velocity.dtype,
                               shape=self.velocity.shape[1:]))

  def _to_tensor_list(self, value):
    return [array_ops.reshape(
                value.mass,
                self._pad_shape_to_full_rank(value.mass.shape)),
            array_ops.reshape(
                value.velocity,
                self._pad_shape_to_full_rank(value.velocity.shape))]
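

# Illustrative sketch only: how the spec above behaves under batching. With a
# scalar mass and a velocity of shape [3], components are first padded to a
# common rank, so batching by 4 yields component specs of shape [4, 1] and
# [4, 3]. This helper is ours and is not called by the tests.
def _example_particle_batched_spec(batch_size):
  spec = ParticleSpec(
      tensor_spec.TensorSpec(shape=[], dtype=dtypes.float32),
      tensor_spec.TensorSpec(shape=[3], dtype=dtypes.float32))
  return spec._batch(batch_size)  # pylint: disable=protected-access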


class CompositeTensorTest(PForTestCase, parameterized.TestCase):

  @parameterized.parameters((None,), (3,))
  def test_create_composite_inside_loop(self, parallel_iterations):
    num_particles = 10
    velocities = random_ops.random_uniform([num_particles])
    particles = pfor_control_flow_ops.pfor(
        # Build a batch of particles all with the same mass.
        lambda i: Particle(mass=4., velocity=array_ops.gather(velocities, i)),
        num_particles,
        parallel_iterations=parallel_iterations)
    particles_mass, particles_velocity, velocities = self.evaluate(
        (particles.mass, particles.velocity, velocities))
    self.assertAllEqual(particles_mass, 4. * np.ones([num_particles]))
    self.assertAllEqual(particles_velocity, velocities)

  @parameterized.parameters((None,), (3,))
  def test_composite_is_converted_to_batched_tensor(
      self, parallel_iterations):
    particles = pfor_control_flow_ops.pfor(
        lambda _: Particle(mass=random_ops.random_uniform([3]),  # pylint: disable=g-long-lambda
                           velocity=random_ops.random_uniform([5, 3])),
        4,
        parallel_iterations=parallel_iterations)
    # Naively batching the component shapes would give `[4, 3]` and `[4, 5, 3]`
    # which have no consistent broadcast shape.
    self.assertEqual(particles.mass.shape, [4, 1, 3])
    self.assertAllEqual(particles.velocity.shape, [4, 5, 3])

  def test_vectorized_map_gathers_composite_tensors(self):
    particles = Particle(mass=[1., 2., 3., 4., 5.],
                         velocity=[1., 2., 3., 4., 5.])
    self.assertAllEqual(
        pfor_control_flow_ops.vectorized_map(
            lambda x: x.mass * x.velocity, particles),
        particles.mass * particles.velocity)

  def test_vectorized_map_of_ragged_tensors(self):
    # Vmap should be able to handle ragged Tensors as long as they're not
    # *actually* ragged.
    ragged = ragged_tensor.RaggedTensor.from_uniform_row_length(
        ragged_tensor.RaggedTensor.from_row_lengths(
            values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
            row_lengths=[3, 3, 3, 3]),
        uniform_row_length=2)  # Overall shape [2, 2, 3].
    self.assertAllEqual(
        pfor_control_flow_ops.vectorized_map(
            lambda x: x.to_tensor(shape=[2, 3]), ragged),
        ragged.to_tensor(shape=[2, 2, 3]))


class ParsingTest(PForTestCase):

  def test_decode_csv(self):
    csv_tensor = constant_op.constant([["1:2:3"], ["::"], ["7:8:9"]])
    kwargs = {"record_defaults": [[10], [20], [30]], "field_delim": ":"}

    def loop_fn(i):
      line = array_ops.gather(csv_tensor, i)
      return parsing_ops.decode_csv(line, **kwargs)

    self._test_loop_fn(loop_fn, iters=3)

  @test_util.run_v1_only("b/122612051")
  def test_parse_single_example(self):

    def _int64_feature(*values):
      return feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=values))

    def _bytes_feature(*values):
      return feature_pb2.Feature(
          bytes_list=feature_pb2.BytesList(
              value=[v.encode("utf-8") for v in values]))

    examples = constant_op.constant([
        example_pb2.Example(
            features=feature_pb2.Features(
                feature={
                    "dense_int": _int64_feature(i),
                    "dense_str": _bytes_feature(str(i)),
                    "sparse_int": _int64_feature(i, i * 2, i * 4, i * 8),
                    "sparse_str": _bytes_feature(*["abc"] * i)
                })).SerializeToString() for i in range(10)
    ])

    features = {
        "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0),
        "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""),
        "sparse_int": parsing_ops.VarLenFeature(dtypes.int64),
        "sparse_str": parsing_ops.VarLenFeature(dtypes.string),
    }

    def loop_fn(i):
      example_proto = array_ops.gather(examples, i)
      f = parsing_ops.parse_single_example(example_proto, features)
      return f

    pfor = pfor_control_flow_ops.pfor(loop_fn, iters=10)
    manual = parsing_ops.parse_example(examples, features)
    self.run_and_assert_equal(pfor, manual)


class PartitionedCallTest(PForTestCase):

  def test_simple(self):

    @def_function.function
    def f(x):
      return math_ops.square(x) + 1

    z = random_ops.random_uniform([4])

    def loop_fn(i):
      return f(array_ops.gather(z, i))

    self._test_loop_fn(loop_fn, 4)

  def test_nested_calls(self):

    @def_function.function
    def inner(x):
      return math_ops.square(x)

    @def_function.function
    def outer(y):
      return math_ops.reduce_sum(inner(y)) + 2

    z = random_ops.random_uniform([4, 2])

    def loop_fn(i):
      return outer(array_ops.gather(z, i))

    self._test_loop_fn(loop_fn, 4)

  def test_nested_calls_loop_fn_autograph(self):
    # TODO(bhack): Do we need to extend the coverage?

    def loop_fn(x):
      for y in range(array_ops.constant(3)):
        pass
      return math_ops.square(x)

    @def_function.function
    def loop_fn_caller():
      self._test_loop_fn(loop_fn, 4)

    loop_fn_caller()

  def test_nested_definition(self):

    @def_function.function
    def outer(y):

      @def_function.function
      def inner(x):
        return math_ops.square(x) + 1

      return math_ops.reduce_sum(inner(y)) + 2

    z = random_ops.random_uniform([4, 2])

    def loop_fn(i):
      return outer(array_ops.gather(z, i))

    self._test_loop_fn(loop_fn, 4)

  def test_gradients(self):

    @def_function.function
    def f(x):
      return math_ops.square(x) + 1

    z = random_ops.random_uniform([4, 2])

    def loop_fn(i):
      z_i = array_ops.gather(z, i)
      with backprop.GradientTape() as g:
        g.watch(z_i)
        out = f(z_i)
      return out, g.gradient(out, z_i)

    self._test_loop_fn(loop_fn, 4)

  def test_stateful_with_gradients(self):

    z = random_ops.random_uniform([4, 2])
    v = variables.Variable(z[0])

    @def_function.function
    def f(x):
      return math_ops.square(x) + v + 1

    def loop_fn(i):
      z_i = array_ops.gather(z, i)
      with backprop.GradientTape() as g:
        g.watch(z_i)
        out = f(z_i)
      return out, g.gradient(out, z_i)

    self._test_loop_fn(loop_fn, 4)


class SpectralTest(PForTestCase, parameterized.TestCase):

  @parameterized.parameters(
      (fft_ops.fft,),
      (fft_ops.fft2d,),
      (fft_ops.fft3d,),
      (fft_ops.ifft,),
      (fft_ops.ifft2d,),
      (fft_ops.ifft3d,),
  )
  def test_fft(self, op_func):
    shape = [2, 3, 4, 3, 4]
    x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return op_func(x_i)

    self._test_loop_fn(loop_fn, 2)

  @parameterized.parameters(
      (fft_ops.rfft,),
      (fft_ops.rfft2d,),
      (fft_ops.rfft3d,),
  )
  @test.disable_with_predicate(
      pred=test.is_built_with_rocm,
      skip_message="Disable subtest on ROCm due to rocfft issues")
  def test_rfft(self, op_func):
    for dtype in (dtypes.float32, dtypes.float64):
      x = random_ops.random_uniform([2, 3, 4, 3, 4], dtype=dtype)

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x_i = array_ops.gather(x, i)
        return op_func(x_i)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 2)

  @parameterized.parameters(
      (fft_ops.irfft,),
      (fft_ops.irfft2d,),
      (fft_ops.irfft3d,),
  )
  @test.disable_with_predicate(
      pred=test.is_built_with_rocm,
      skip_message="Disable subtest on ROCm due to rocfft issues")
  def test_irfft(self, op_func):
    if config.list_physical_devices("GPU"):
      # TODO(b/149957923): The test is flaky.
      self.skipTest("b/149957923: irfft vectorization flaky")
    for dtype in (dtypes.complex64, dtypes.complex128):
      shape = [2, 3, 4, 3, 4]
      x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
      x = math_ops.cast(x, dtype=dtype)

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x_i = array_ops.gather(x, i)
        return op_func(x_i)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 2)


class VariableTest(PForTestCase):

  def test_create_variable_once(self):
    x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
    y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
    a_var = []

    def f(z):
      if not a_var:
        a_var.append(variables.Variable(lambda: y, name="a"))
      return math_ops.matmul(z, a_var[0] / 16)

    pfor_control_flow_ops.vectorized_map(f, x)

  @test_util.run_v2_only
  def test_create_variable_repeated(self):
    x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
    y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)

    def f(z):
      a_var = variables.Variable(lambda: y, name="a") / 4
      return math_ops.matmul(z, a_var / 16)

    # Note that this error is only raised under v2 behavior.
    with self.assertRaisesRegex(
        ValueError, "singleton tf.Variable.*on the first call"):
      pfor_control_flow_ops.vectorized_map(f, x)

  @test_util.run_all_in_graph_and_eager_modes
  def test_variable_shape(self):
    v = resource_variable_ops.ResourceVariable([1, 2])

    def loop_fn(_):
      return resource_variable_ops.variable_shape(v.handle)

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_all_in_graph_and_eager_modes
  def test_variable_input(self):
    v = resource_variable_ops.ResourceVariable([1, 2])
    self.evaluate(v.initializer)

    def loop_fn(x):
      return x + 1

    result = pfor_control_flow_ops.vectorized_map(loop_fn, v)
    expected_result = [2, 3]
    self.assertAllEqual(result, expected_result)

  @test_util.run_all_in_graph_and_eager_modes
  def testStatelessCase(self):

    def branch1(x):
      return x

    def branch2(x):
      return x + 1

    def branch3(x):
      return x + 2

    x = constant_op.constant(10)
    elems = constant_op.constant([1, 0, 0, 0, 2, 1, 0, 2, 0, 1])

    def loop_fn(z_i):
      return cond_v2.indexed_case(
          z_i, [lambda: branch1(x), lambda: branch2(x), lambda: branch3(x)])

    result = pfor_control_flow_ops.vectorized_map(
        loop_fn, elems, fallback_to_while_loop=False)

    expected_result = [11, 10, 10, 10, 12, 11, 10, 12, 10, 11]
    self.assertAllEqual(result, expected_result)

  @test_util.run_all_in_graph_and_eager_modes
  def testStatelessCaseUnstacked(self):

    def branch1(x):
      return x + 1

    def branch2(x):
      return x + 2

    # Unstacked case input
    case_input = constant_op.constant(1)

    @def_function.function
    def function(z_i):
      return cond_v2.indexed_case(
          case_input, [lambda: branch1(z_i), lambda: branch2(z_i)])

    inputs = constant_op.constant([0, 1, 1, 0, 1, 0, 1, 0, 0])

    result = pfor_control_flow_ops.vectorized_map(
        function, inputs, fallback_to_while_loop=False)
    expected_result = [2, 3, 3, 2, 3, 2, 3, 2, 2]
    self.assertAllEqual(result, expected_result)


if __name__ == "__main__":
  test.main()