1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for pfor and for_loop.""" 16# pylint: disable=g-direct-tensorflow-import 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import functools 23import sys 24import time 25 26from absl.testing import parameterized 27import numpy as np 28 29from google.protobuf import text_format 30 31from tensorflow.core.example import example_pb2 32from tensorflow.core.example import feature_pb2 33from tensorflow.core.framework import graph_pb2 34from tensorflow.python.client import session 35from tensorflow.python.eager import backprop 36from tensorflow.python.eager import context 37from tensorflow.python.eager import def_function 38from tensorflow.python.framework import composite_tensor 39from tensorflow.python.framework import config 40from tensorflow.python.framework import constant_op 41from tensorflow.python.framework import dtypes 42from tensorflow.python.framework import importer 43from tensorflow.python.framework import indexed_slices 44from tensorflow.python.framework import ops 45from tensorflow.python.framework import sparse_tensor 46from tensorflow.python.framework import tensor_shape 47from tensorflow.python.framework import tensor_spec 48from tensorflow.python.framework import test_util 49from tensorflow.python.framework import type_spec 50from tensorflow.python.ops import array_ops 51from tensorflow.python.ops import bitwise_ops 52from tensorflow.python.ops import cond_v2 53from tensorflow.python.ops import control_flow_ops 54from tensorflow.python.ops import control_flow_v2_toggles 55from tensorflow.python.ops import data_flow_ops 56from tensorflow.python.ops import functional_ops 57from tensorflow.python.ops import gen_dataset_ops 58from tensorflow.python.ops import gen_list_ops 59from tensorflow.python.ops import gen_nn_ops 60from tensorflow.python.ops import gradient_checker_v2 61from tensorflow.python.ops import gradients as gradient_ops 62from tensorflow.python.ops import image_ops 63from tensorflow.python.ops import list_ops 64from tensorflow.python.ops import logging_ops 65from tensorflow.python.ops import manip_ops 66from tensorflow.python.ops import map_fn 67from tensorflow.python.ops import math_ops 68from tensorflow.python.ops import nn 69from tensorflow.python.ops import parsing_ops 70from tensorflow.python.ops import random_ops 71from tensorflow.python.ops import resource_variable_ops 72from tensorflow.python.ops import rnn 73from tensorflow.python.ops import rnn_cell 74from tensorflow.python.ops import stateless_random_ops 75from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import 76from tensorflow.python.ops import tensor_array_ops 77from tensorflow.python.ops import variables 78from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops 79from 
tensorflow.python.ops.parallel_for.test_util import PForTestCase 80from tensorflow.python.ops.ragged import ragged_tensor 81from tensorflow.python.ops.signal import fft_ops 82from tensorflow.python.platform import test 83from tensorflow.python.util import nest 84 85 86@test_util.run_all_in_graph_and_eager_modes 87@test_util.with_control_flow_v2 88class PForTest(PForTestCase): 89 90 def test_op_conversion_fallback_to_while_loop(self): 91 # Note that we used top_k op for this test. If a converter gets defined for 92 # it, we will need to find another op for which a converter has not been 93 # defined. 94 x = random_ops.random_uniform([3, 2, 4]) 95 96 def loop_fn(i): 97 x_i = array_ops.gather(x, i) 98 return nn.top_k(x_i) 99 100 with self.assertRaisesRegex(ValueError, "No pfor vectorization"): 101 self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False) 102 self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True) 103 104 def test_parallel_iterations(self): 105 for parallel_iterations in [2, 3, 8, 10]: 106 x = random_ops.random_uniform([8, 3]) 107 108 # pylint: disable=cell-var-from-loop 109 def loop_fn(i): 110 return array_ops.gather(x, i) 111 112 # pylint: enable=cell-var-from-loop 113 114 self._test_loop_fn(loop_fn, 8, parallel_iterations=parallel_iterations) 115 self._test_loop_fn( 116 loop_fn, 117 4 * constant_op.constant(2), 118 parallel_iterations=parallel_iterations) 119 120 def test_parallel_iterations_preserves_static_shape(self): 121 for parallel_iterations in [2, 3, 8, 10]: 122 x = pfor_control_flow_ops.pfor( 123 lambda _: random_ops.random_uniform([2, 3]), 124 8, 125 parallel_iterations=parallel_iterations) 126 self.assertAllEqual(x.shape, [8, 2, 3]) 127 128 def test_parallel_iterations_zero(self): 129 with self.assertRaisesRegex(ValueError, "positive integer"): 130 pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=0) 131 with self.assertRaisesRegex(TypeError, "positive integer"): 132 pfor_control_flow_ops.for_loop( 133 lambda i: 1, dtypes.int32, 8, parallel_iterations=0) 134 135 def test_parallel_iterations_one(self): 136 with self.assertRaisesRegex(ValueError, "Use for_loop instead"): 137 pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=1) 138 139 def test_vectorized_map(self): 140 141 def compute(x): 142 return math_ops.reduce_mean(x, axis=0, keepdims=True) 143 144 result = pfor_control_flow_ops.vectorized_map(compute, 145 array_ops.ones((10, 5, 3))) 146 self.run_and_assert_equal(result, array_ops.ones((10, 1, 3))) 147 148 def test_vectorized_map_with_dynamic_shape(self): 149 150 def compute(x): 151 return math_ops.reduce_mean(x, axis=0, keepdims=True) 152 153 x = array_ops.placeholder_with_default( 154 array_ops.ones((10, 5, 3)), shape=None) 155 result = pfor_control_flow_ops.vectorized_map(compute, x) 156 self.run_and_assert_equal(result, array_ops.ones((10, 1, 3))) 157 158 def test_where_shape(self): 159 @def_function.function 160 def f(): 161 a = constant_op.constant([[1.], [1.]]) 162 b = constant_op.constant([1.]) 163 result = pfor_control_flow_ops.vectorized_map( 164 lambda x: array_ops.where(x > 0, x, b), a) 165 return result.shape 166 167 self.assertAllEqual([2, 1], f()) 168 169 def test_vectorized_map_broadcasts_unit_dimensions(self): 170 convert_with_static_shape = ops.convert_to_tensor 171 convert_with_dynamic_shape = ( 172 lambda x: array_ops.placeholder_with_default(x, shape=None)) 173 174 for convert in (convert_with_static_shape, convert_with_dynamic_shape): 175 a = convert([3.1]) 176 b = convert([-2., 6., 9.]) 177 178 # One elem 
with leading unit dimension. 179 a_plus_1 = pfor_control_flow_ops.vectorized_map(lambda a: a + 1, a) 180 self.assertAllEqual(*self.evaluate((a_plus_1, a + 1))) 181 182 # Two elems, both with leading unit dimension. 183 a_plus_a = pfor_control_flow_ops.vectorized_map(sum, (a, a)) 184 self.assertAllEqual(*self.evaluate((a_plus_a, a + a))) 185 186 # Elem w/ unit dimension broadcast against elem with batch dim. 187 a_plus_b = pfor_control_flow_ops.vectorized_map(sum, (a, b)) 188 self.assertAllEqual(*self.evaluate((a_plus_b, a + b))) 189 190 def test_vectorized_map_example_1(self): 191 192 def outer_product(a): 193 return math_ops.tensordot(a, a, 0) 194 195 batch_size = 100 196 a = array_ops.ones((batch_size, 32, 32)) 197 c = pfor_control_flow_ops.vectorized_map(outer_product, a) 198 self.assertAllEqual((batch_size, 32, 32, 32, 32), c.shape) 199 200 def test_disable_tf_function(self): 201 def_function.run_functions_eagerly(True) 202 # vectorized_map should ignore disabling tf.functions 203 self.assertTrue(def_function.functions_run_eagerly()) 204 self.assertAllEqual([0, 1, 4, 9], 205 pfor_control_flow_ops.vectorized_map( 206 lambda x: x * x, math_ops.range(4))) 207 self.assertTrue(def_function.functions_run_eagerly()) 208 def_function.run_functions_eagerly(False) 209 210 211@test_util.run_all_in_graph_and_eager_modes 212class IndexedSlicesTest(PForTestCase): 213 214 def test_indexed_slices(self): 215 216 def loop_fn(i): 217 return indexed_slices.IndexedSlices( 218 indices=i, values=array_ops.reshape(i, [1]), dense_shape=[3, 1]) 219 220 self._test_loop_fn(loop_fn, 2) 221 222 def test_indexed_slices_components(self): 223 224 def loop_fn(i): 225 slices = indexed_slices.IndexedSlices( 226 indices=i, values=array_ops.reshape(i, [1]), dense_shape=[3, 1]) 227 # Note that returning the components inside the slice avoids 228 # densification, which may be more efficient. 
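      # pfor can then vectorize each component tensor independently.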
229 return slices.values, slices.indices 230 231 self._test_loop_fn(loop_fn, 2) 232 233 234@test_util.run_all_in_graph_and_eager_modes 235class ReductionTest(PForTestCase): 236 237 def test_reduce(self): 238 239 def reduce_fn(p, q): 240 return math_ops.reduce_mean(p + q, axis=0) 241 242 x = random_ops.random_uniform([4, 3, 2]) 243 y = random_ops.random_uniform([4, 3, 2]) 244 245 def loop_fn(i, pfor_config): 246 x_i = array_ops.gather(x, i) 247 y_i = array_ops.gather(y, i) 248 reduced = pfor_config.reduce(reduce_fn, x_i, y_i) 249 return reduced + x_i 250 251 output = pfor_control_flow_ops.pfor(loop_fn, 4) 252 ans = reduce_fn(x, y) + x 253 output_val, ans_val = self.evaluate([output, ans]) 254 self.assertAllClose(ans_val, output_val) 255 256 def test_reduce_concat(self): 257 x = random_ops.random_uniform([8, 3]) 258 259 def loop_fn(i, pfor_config): 260 x_i = array_ops.gather(x, i) 261 vectorized_value = pfor_config.reduce_concat(x_i) 262 mean_value = math_ops.reduce_mean(vectorized_value, axis=0) 263 return x_i - mean_value 264 265 output = pfor_control_flow_ops.pfor(loop_fn, 8) 266 ans = x - math_ops.reduce_mean(x, axis=0) 267 output_val, ans_val = self.evaluate([output, ans]) 268 self.assertAllClose(ans_val, output_val) 269 270 def test_reduce_mean(self): 271 x = random_ops.random_uniform([8, 3]) 272 273 def loop_fn(i, pfor_config): 274 x_i = array_ops.gather(x, i) 275 return x_i - pfor_config.reduce_mean(x_i) 276 277 output = pfor_control_flow_ops.pfor(loop_fn, 8) 278 ans = x - math_ops.reduce_mean(x, axis=0) 279 output_val, ans_val = self.evaluate([output, ans]) 280 self.assertAllClose(ans_val, output_val) 281 282 def test_reduce_sum(self): 283 x = random_ops.random_uniform([8, 3]) 284 285 def loop_fn(i, pfor_config): 286 x_i = array_ops.gather(x, i) 287 return x_i - pfor_config.reduce_sum(x_i) 288 289 output = pfor_control_flow_ops.pfor(loop_fn, 8) 290 ans = x - math_ops.reduce_sum(x, axis=0) 291 output_val, ans_val = self.evaluate([output, ans]) 292 self.assertAllClose(ans_val, output_val) 293 294 def test_reduce_class(self): 295 x = random_ops.random_uniform([8, 3]) 296 297 class LoopFn(object): 298 299 def __init__(self): 300 pass 301 302 def __call__(self, i, pfor_config): 303 x_i = array_ops.gather(x, i) 304 return x_i - pfor_config.reduce_mean(x_i) 305 306 output = pfor_control_flow_ops.pfor(LoopFn(), 8) 307 ans = x - math_ops.reduce_mean(x, axis=0) 308 output_val, ans_val = self.evaluate([output, ans]) 309 self.assertAllClose(ans_val, output_val) 310 311 def test_reduce_functools_partial(self): 312 x = random_ops.random_uniform([8, 3]) 313 314 def fn(i, pfor_config, dummy=None): 315 del dummy 316 x_i = array_ops.gather(x, i) 317 return x_i - pfor_config.reduce_mean(x_i) 318 319 loop_fn = functools.partial(fn, dummy=1) 320 output = pfor_control_flow_ops.pfor(loop_fn, 8) 321 ans = x - math_ops.reduce_mean(x, axis=0) 322 output_val, ans_val = self.evaluate([output, ans]) 323 self.assertAllClose(ans_val, output_val) 324 325 def test_parallel_iterations(self): 326 x = random_ops.random_uniform([8, 3]) 327 328 def loop_fn(i, pfor_config): 329 x_i = array_ops.gather(x, i) 330 return pfor_config.reduce_sum(x_i) 331 332 with self.assertRaisesRegex(ValueError, 333 "parallel_iterations currently unsupported"): 334 pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2) 335 336 def test_var_loop_len(self): 337 if context.executing_eagerly(): 338 self.skipTest("Variable length not possible under eager execution.") 339 340 x = random_ops.random_uniform([8, 3]) 341 342 def loop_fn(i, 
pfor_config): 343 return pfor_config.reduce_sum(array_ops.gather(x, i)) 344 345 num_iters = array_ops.placeholder(dtypes.int32) 346 pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) 347 with self.cached_session() as sess: 348 sess.run(pfor, feed_dict={num_iters: 8}) 349 350 351@test_util.run_all_in_graph_and_eager_modes 352class BitwiseTest(PForTestCase): 353 354 def test_unary_cwise(self): 355 for op in [bitwise_ops.invert]: 356 x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32) 357 358 # pylint: disable=cell-var-from-loop 359 def loop_fn(i): 360 x1 = array_ops.gather(x, i) 361 return op(x1) 362 363 # pylint: enable=cell-var-from-loop 364 365 self._test_loop_fn(loop_fn, 3) 366 367 def test_binary_cwise(self): 368 binary_ops = [ 369 bitwise_ops.bitwise_and, 370 bitwise_ops.bitwise_or, 371 bitwise_ops.bitwise_xor, 372 bitwise_ops.left_shift, 373 bitwise_ops.right_shift, 374 ] 375 for op in binary_ops: 376 x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32) 377 y = random_ops.random_uniform([3, 5], maxval=10, dtype=dtypes.int32) 378 379 output_dtypes = [] 380 381 # pylint: disable=cell-var-from-loop 382 def loop_fn(i): 383 x1 = array_ops.gather(x, i) 384 y1 = array_ops.gather(y, i) 385 outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)] 386 del output_dtypes[:] 387 output_dtypes.extend(t.dtype for t in outputs) 388 return outputs 389 390 # pylint: enable=cell-var-from-loop 391 self._test_loop_fn(loop_fn, 3) 392 393 394@test_util.run_all_in_graph_and_eager_modes 395class ImageTest(PForTestCase): 396 397 def test_adjust_contrast(self): 398 images = random_ops.random_uniform([3, 2, 4, 4, 3]) 399 400 def loop_fn(i): 401 image = array_ops.gather(images, i) 402 return image_ops.adjust_contrast(image, 2.0) 403 404 self._test_loop_fn(loop_fn, 3) 405 406 def test_adjust_hue(self): 407 images = random_ops.random_uniform([3, 2, 4, 4, 3]) 408 409 def loop_fn(i): 410 image = array_ops.gather(images, i) 411 return image_ops.adjust_hue(image, .25) 412 413 self._test_loop_fn(loop_fn, 3) 414 415 def test_adjust_saturation(self): 416 images = random_ops.random_uniform([3, 2, 4, 4, 3]) 417 418 def loop_fn(i): 419 image = array_ops.gather(images, i) 420 return image_ops.adjust_saturation(image, 0.1) 421 422 self._test_loop_fn(loop_fn, 3) 423 424 425@test_util.run_all_in_graph_and_eager_modes 426class NNTest(PForTestCase): 427 428 def test_conv2d(self): 429 x = random_ops.random_uniform([3, 2, 12, 12, 3]) 430 filt = random_ops.random_uniform([3, 3, 3, 7]) 431 432 def loop_fn(i): 433 x1 = array_ops.gather(x, i) 434 return nn.conv2d( 435 x1, filt, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC") 436 437 self._test_loop_fn(loop_fn, 3) 438 439 def test_conv2d_backprop_input(self): 440 x_shape = [2, 12, 12, 3] 441 filt = random_ops.random_uniform([3, 3, 3, 7]) 442 grad = random_ops.random_uniform([3, 2, 5, 5, 7]) 443 444 def loop_fn(i): 445 grad1 = array_ops.gather(grad, i) 446 return nn.conv2d_backprop_input( 447 x_shape, 448 filt, 449 grad1, 450 strides=[1, 2, 2, 1], 451 padding="VALID", 452 data_format="NHWC") 453 454 self._test_loop_fn(loop_fn, 3) 455 456 def test_conv2d_backprop_filter(self): 457 x = random_ops.random_uniform([3, 2, 12, 12, 3]) 458 x_0 = array_ops.gather(x, 0) 459 filter_sizes = [3, 3, 3, 7] 460 grad = random_ops.random_uniform([3, 2, 5, 5, 7]) 461 462 def loop_fn(i): 463 x_i = array_ops.gather(x, i) 464 grad_i = array_ops.gather(grad, i) 465 return [ 466 nn.conv2d_backprop_filter( 467 inp, 468 filter_sizes, 469 grad_i, 470 
strides=[1, 2, 2, 1], 471 padding="VALID", 472 data_format="NHWC") for inp in [x_i, x_0] 473 ] 474 475 self._test_loop_fn(loop_fn, 3) 476 477 def test_depthwise_conv2d_native(self): 478 x = random_ops.random_uniform([3, 2, 12, 12, 3]) 479 filt = random_ops.random_uniform([3, 3, 3, 3, 2]) 480 481 def loop_fn(i): 482 x1 = array_ops.gather(x, i) 483 filt1 = array_ops.gather(filt, i) 484 return nn.depthwise_conv2d_native( 485 x1, filt1, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC") 486 487 self._test_loop_fn(loop_fn, 3) 488 489 def test_depthwise_conv2d_native_backprop_input(self): 490 x_shape = [2, 12, 12, 3] 491 filt = random_ops.random_uniform([3, 3, 3, 3, 2]) 492 grad = random_ops.random_uniform([3, 2, 5, 5, 6]) 493 494 def loop_fn(i): 495 grad1 = array_ops.gather(grad, i) 496 filt1 = array_ops.gather(filt, i) 497 return nn.depthwise_conv2d_native_backprop_input( 498 x_shape, 499 filt1, 500 grad1, 501 strides=[1, 2, 2, 1], 502 padding="VALID", 503 data_format="NHWC") 504 505 self._test_loop_fn(loop_fn, 3) 506 507 def test_depthwise_conv2d_native_backprop_filter(self): 508 x = random_ops.random_uniform([3, 2, 12, 12, 3]) 509 filter_sizes = [3, 3, 3, 2] 510 grad = random_ops.random_uniform([3, 2, 5, 5, 6]) 511 512 def loop_fn(i): 513 x_i = array_ops.gather(x, i) 514 grad_i = array_ops.gather(grad, i) 515 return nn.depthwise_conv2d_native_backprop_filter( 516 x_i, 517 filter_sizes, 518 grad_i, 519 strides=[1, 2, 2, 1], 520 padding="VALID", 521 data_format="NHWC") 522 523 self._test_loop_fn(loop_fn, 3) 524 525 def test_depthwise_conv2d_native_nchw(self): 526 if not test_util.is_gpu_available(): 527 self.skipTest("NCHW only works on GPU") 528 x = random_ops.random_uniform([3, 2, 3, 12, 12]) 529 filt = random_ops.random_uniform([3, 3, 3, 3, 2]) 530 531 def loop_fn(i): 532 x1 = array_ops.gather(x, i) 533 filt1 = array_ops.gather(filt, i) 534 return nn.depthwise_conv2d_native( 535 x1, filt1, strides=[1, 1, 2, 2], padding="VALID", data_format="NCHW") 536 537 self._test_loop_fn(loop_fn, 3) 538 539 def test_depthwise_conv2d_native_backprop_input_nchw(self): 540 if not test_util.is_gpu_available(): 541 self.skipTest("NCHW only works on GPU") 542 x_shape = [2, 3, 12, 12] 543 filt = random_ops.random_uniform([3, 3, 3, 3, 2]) 544 grad = random_ops.random_uniform([3, 2, 6, 5, 5]) 545 546 def loop_fn(i): 547 grad1 = array_ops.gather(grad, i) 548 filt1 = array_ops.gather(filt, i) 549 return nn.depthwise_conv2d_native_backprop_input( 550 x_shape, 551 filt1, 552 grad1, 553 strides=[1, 1, 2, 2], 554 padding="VALID", 555 data_format="NCHW") 556 557 self._test_loop_fn(loop_fn, 3) 558 559 def test_depthwise_conv2d_native_backprop_filter_nchw(self): 560 if not test_util.is_gpu_available(): 561 self.skipTest("NCHW only works on GPU") 562 x = random_ops.random_uniform([3, 2, 3, 12, 12]) 563 filter_sizes = [3, 3, 3, 2] 564 grad = random_ops.random_uniform([3, 2, 6, 5, 5]) 565 566 def loop_fn(i): 567 x_i = array_ops.gather(x, i) 568 grad_i = array_ops.gather(grad, i) 569 return nn.depthwise_conv2d_native_backprop_filter( 570 x_i, 571 filter_sizes, 572 grad_i, 573 strides=[1, 1, 2, 2], 574 padding="VALID", 575 data_format="NCHW") 576 577 self._test_loop_fn(loop_fn, 3) 578 579 def test_roll(self): 580 x = random_ops.random_uniform([3, 6, 7]) 581 582 def loop_fn(i): 583 x_i = array_ops.gather(x, i) 584 return manip_ops.roll(x_i, 3, axis=1) 585 586 self._test_loop_fn(loop_fn, 3) 587 588 def test_ensure_shape(self): 589 x = random_ops.random_uniform([3, 6, 7]) 590 591 def loop_fn(i): 592 x_i = 
array_ops.gather(x, i) 593 return array_ops.ensure_shape(x_i, [6, 7]) 594 595 self._test_loop_fn(loop_fn, 3) 596 597 def test_loop_variant_roll_shift(self): 598 self.skipTest("TODO(b/191880259): re-enable once XLA compile times are " 599 "addressed.") 600 x = random_ops.random_uniform([3, 5, 6, 7]) 601 602 def loop_fn(i): 603 x_i = array_ops.gather(x, i) 604 return manip_ops.roll(x_i, [i - 2, -1, i], axis=[1, 2, 2]) 605 606 self._test_loop_fn(loop_fn, 3) 607 608 def test_loop_variant_roll_scalar_shift(self): 609 self.skipTest("TODO(b/191880259): re-enable once XLA compile times are " 610 "addressed.") 611 x = random_ops.random_uniform([5, 5, 6]) 612 613 def loop_fn(i): 614 x_i = array_ops.gather(x, i) 615 return manip_ops.roll(x_i, i, axis=0) 616 617 self._test_loop_fn(loop_fn, 5) 618 619 def test_avg_pool(self): 620 with backprop.GradientTape(persistent=True) as g: 621 x = random_ops.random_uniform([3, 2, 12, 12, 3]) 622 g.watch(x) 623 ksize = [1, 3, 3, 1] 624 625 def loop_fn(i): 626 with g: 627 x1 = array_ops.gather(x, i) 628 output = nn.avg_pool( 629 x1, 630 ksize, 631 strides=[1, 2, 2, 1], 632 padding="VALID", 633 data_format="NHWC") 634 loss = nn.l2_loss(output) 635 return output, g.gradient(loss, x1) 636 637 self._test_loop_fn(loop_fn, 3) 638 639 def test_avg_pool3d(self): 640 with backprop.GradientTape(persistent=True) as g: 641 x = random_ops.random_uniform([5, 3, 7, 6, 6, 5]) 642 g.watch(x) 643 ksize = [1, 2, 2, 2, 1] 644 strides = [1, 2, 2, 2, 1] 645 646 def loop_fn(i): 647 with g: 648 x1 = array_ops.gather(x, i) 649 output = nn.avg_pool3d( 650 x1, ksize, strides=strides, padding="VALID", data_format="NDHWC") 651 loss = nn.l2_loss(output) 652 return output, g.gradient(loss, x1) 653 654 self._test_loop_fn(loop_fn, 3) 655 656 def test_max_pool(self): 657 with backprop.GradientTape(persistent=True) as g: 658 x = random_ops.random_uniform([3, 2, 12, 12, 3]) 659 g.watch(x) 660 ksize = [1, 3, 3, 1] 661 strides = [1, 2, 2, 1] 662 663 def loop_fn(i): 664 with g: 665 x1 = array_ops.gather(x, i) 666 output = nn.max_pool( 667 x1, ksize, strides=strides, padding="VALID", data_format="NHWC") 668 loss = nn.l2_loss(output) 669 ones = array_ops.ones_like(output) 670 g.watch(ones) 671 grad = g.gradient(loss, x1, output_gradients=ones) 672 grad_grad = g.gradient(grad, ones) 673 return output, grad, grad_grad 674 675 self._test_loop_fn(loop_fn, 3) 676 677 def test_max_pool_v2(self): 678 with backprop.GradientTape(persistent=True) as g: 679 x = random_ops.random_uniform([3, 2, 12, 12, 3]) 680 g.watch(x) 681 ksize = [1, 3, 3, 1] 682 strides = [1, 2, 2, 1] 683 684 def loop_fn(i): 685 with g: 686 x1 = array_ops.gather(x, i) 687 output = gen_nn_ops.max_pool_v2( 688 x1, ksize, strides=strides, padding="VALID", data_format="NHWC") 689 loss = nn.l2_loss(output) 690 ones = array_ops.ones_like(output) 691 g.watch(ones) 692 grad = g.gradient(loss, x1, output_gradients=ones) 693 grad_grad = g.gradient(grad, ones) 694 return output, grad, grad_grad 695 696 self._test_loop_fn(loop_fn, 3) 697 698 def test_max_pool3d(self): 699 with backprop.GradientTape(persistent=True) as g: 700 x = random_ops.random_uniform([3, 3, 2, 12, 12, 3]) 701 g.watch(x) 702 ksize = [1, 1, 3, 3, 1] 703 strides = [1, 1, 2, 2, 1] 704 705 def loop_fn(i): 706 with g: 707 x1 = array_ops.gather(x, i) 708 output = nn.max_pool3d( 709 x1, ksize, strides=strides, padding="VALID", data_format="NDHWC") 710 loss = nn.l2_loss(output) 711 ones = array_ops.ones_like(output) 712 g.watch(ones) 713 grad = g.gradient(loss, x1, output_gradients=ones) 714 
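          # Also take a second-order gradient through the pooling op.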
grad_grad = g.gradient(grad, ones) 715 return output, grad, grad_grad 716 717 self._test_loop_fn(loop_fn, 3) 718 719 def test_fused_batch_norm(self): 720 data_formats = ["NHWC"] 721 if test.is_gpu_available(): 722 data_formats.append("NCHW") 723 for is_training in (True, False): 724 for data_format in data_formats: 725 with backprop.GradientTape(persistent=True) as g: 726 if data_format == "NCHW": 727 x = random_ops.random_uniform([3, 1, 2, 5, 5]) 728 else: 729 x = random_ops.random_uniform([3, 1, 5, 5, 2]) 730 g.watch(x) 731 scale = random_ops.random_uniform([2]) 732 g.watch(scale) 733 offset = random_ops.random_uniform([2]) 734 g.watch(offset) 735 mean = None if is_training else random_ops.random_uniform([2]) 736 variance = None if is_training else random_ops.random_uniform([2]) 737 738 # pylint: disable=cell-var-from-loop 739 def loop_fn(i): 740 with g: 741 x1 = array_ops.gather(x, i) 742 outputs = nn.fused_batch_norm( 743 x1, 744 scale, 745 offset, 746 mean=mean, 747 variance=variance, 748 epsilon=0.01, 749 data_format=data_format, 750 is_training=is_training) 751 outputs = list(outputs) 752 # We only test the first value of outputs when is_training is 753 # False. It looks like CPU and GPU have different outputs for 754 # batch_mean and batch_variance for this case. 755 if not is_training: 756 outputs[1] = constant_op.constant(0.) 757 outputs[2] = constant_op.constant(0.) 758 loss = nn.l2_loss(outputs[0]) 759 if is_training: 760 gradients = g.gradient(loss, [x1, scale, offset]) 761 else: 762 gradients = [constant_op.constant(0.)] * 3 763 return outputs + gradients 764 765 # pylint: enable=cell-var-from-loop 766 767 self._test_loop_fn(loop_fn, 3) 768 769 def test_log_softmax(self): 770 logits = random_ops.random_uniform([3, 2, 4]) 771 772 def loop_fn(i): 773 logits_i = array_ops.gather(logits, i) 774 return (nn.log_softmax(logits_i), nn.log_softmax(logits_i, axis=0), 775 nn.log_softmax(logits_i, axis=-1)) 776 777 self._test_loop_fn(loop_fn, 3) 778 779 def test_softmax(self): 780 logits = random_ops.random_uniform([3, 2, 4]) 781 782 def loop_fn(i): 783 logits_i = array_ops.gather(logits, i) 784 return (nn.softmax(logits_i), nn.softmax(logits_i, axis=0), 785 nn.softmax(logits_i, axis=-1)) 786 787 self._test_loop_fn(loop_fn, 3) 788 789 def test_softmax_cross_entropy_with_logits(self): 790 with backprop.GradientTape(persistent=True) as g: 791 logits = random_ops.random_uniform([3, 2, 4]) 792 g.watch(logits) 793 labels = random_ops.random_uniform([3, 2, 4]) 794 labels /= math_ops.reduce_sum(labels, axis=[2], keepdims=True) 795 796 def loop_fn(i): 797 with g: 798 logits_i = array_ops.gather(logits, i) 799 labels_i = array_ops.gather(labels, i) 800 loss = nn.softmax_cross_entropy_with_logits( 801 labels=labels_i, logits=logits_i) 802 total_loss = math_ops.reduce_sum(loss) 803 return loss, g.gradient(total_loss, logits_i) 804 805 self._test_loop_fn(loop_fn, 3) 806 807 def test_sparse_softmax_cross_entropy_with_logits(self): 808 logits = random_ops.random_uniform([3, 2, 4]) 809 labels = random_ops.random_uniform( 810 shape=[3, 2], maxval=4, dtype=dtypes.int32) 811 812 def loop_fn(i): 813 logits_i = array_ops.gather(logits, i) 814 labels_i = array_ops.gather(labels, i) 815 loss = nn.sparse_softmax_cross_entropy_with_logits( 816 labels=labels_i, logits=logits_i) 817 return loss 818 819 self._test_loop_fn(loop_fn, 3) 820 821 822class RandomTest(PForTestCase): 823 824 # The random values generated in the two implementations are not guaranteed to 825 # match. So we only check the returned shapes. 
826 def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5): 827 outputs = self._run_targets(targets1, targets2) 828 n = len(outputs) // 2 829 for i in range(n): 830 self.assertAllEqual(outputs[i].shape, outputs[i + n].shape) 831 832 def test_random_uniform(self): 833 834 def loop_fn(_): 835 return random_ops.random_uniform([3]) 836 837 self._test_loop_fn(loop_fn, 5) 838 839 def test_random_uniform_int(self): 840 841 def loop_fn(_): 842 return random_ops.random_uniform([3], maxval=1, dtype=dtypes.int32) 843 844 self._test_loop_fn(loop_fn, 5) 845 846 def test_random_standard_normal(self): 847 848 def loop_fn(_): 849 return random_ops.random_normal([3]) 850 851 self._test_loop_fn(loop_fn, 5) 852 853 def test_truncated_normal(self): 854 855 def loop_fn(_): 856 return random_ops.truncated_normal([3]) 857 858 self._test_loop_fn(loop_fn, 5) 859 860 def test_random_gamma_invariant_alpha(self): 861 862 def loop_fn(_): 863 return random_ops.random_gamma([3], alpha=[0.5]) 864 865 self._test_loop_fn(loop_fn, 5) 866 867 def test_random_gamma_varying_alpha(self): 868 alphas = math_ops.exp(random_ops.random_normal([5, 3, 2])) 869 870 def loop_fn(i): 871 alphas_i = array_ops.gather(alphas, i) 872 # Test both scalar and non-scalar params and shapes. 873 return (random_ops.random_gamma(alpha=alphas_i[0, 0], shape=[]), 874 random_ops.random_gamma(alpha=alphas_i, shape=[]), 875 random_ops.random_gamma(alpha=alphas_i[0, 0], shape=[3]), 876 random_ops.random_gamma(alpha=alphas_i, shape=[3])) 877 878 self._test_loop_fn(loop_fn, 5) 879 880 def test_random_poisson_v2_invariant_rate(self): 881 882 def loop_fn(_): 883 return random_ops.random_poisson(lam=[1.3], shape=[3]) 884 885 self._test_loop_fn(loop_fn, 5) 886 887 def test_random_poisson_v2_varying_rate(self): 888 rates = math_ops.exp(random_ops.random_normal([5, 3, 2])) 889 890 def loop_fn(i): 891 rates_i = array_ops.gather(rates, i) 892 # Test both scalar and non-scalar params and shapes. 893 return (random_ops.random_poisson(lam=rates_i[0, 0], shape=[]), 894 random_ops.random_poisson(lam=rates_i, shape=[]), 895 random_ops.random_poisson(lam=rates_i[0, 0], shape=[3]), 896 random_ops.random_poisson(lam=rates_i, shape=[3])) 897 898 self._test_loop_fn(loop_fn, 5) 899 900 def test_random_multinomial_invariant_logits(self): 901 902 def loop_fn(_): 903 return random_ops.categorical(logits=[[1., -1.]], num_samples=3) 904 905 self._test_loop_fn(loop_fn, 5) 906 907 def test_random_multinomial_varying_logits(self): 908 logits = random_ops.random_normal([5, 3, 2]) 909 910 def loop_fn(i): 911 logits_i = array_ops.gather(logits, i) 912 return random_ops.categorical(logits_i, num_samples=3) 913 914 self._test_loop_fn(loop_fn, 5) 915 916 917class StatelessRandomTest(PForTestCase): 918 919 # This test currently only tests that the vectorized and non-vectorized 920 # outputs have same shapes. This is needed since under XLA compilation, 921 # stateless random numbers can generate different random numbers. 922 # TODO(agarwal): switch to checking for actual values matching once 923 # b/149402339 is resolved. 
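  # Until then, compare only the output shapes (same override as RandomTest).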
  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
    outputs = self._run_targets(targets1, targets2)
    n = len(outputs) // 2
    for i in range(n):
      self.assertAllEqual(outputs[i].shape, outputs[i + n].shape)

  # TODO(agarwal): add tests for other random functions
  def test_multinomial(self):
    seeds = [[1, 2], [3, 4]]
    logits = random_ops.random_uniform([2, 3, 4])

    def loop_fn(i):
      logits_0 = array_ops.gather(logits, 0)
      logits_i = array_ops.gather(logits, i)
      seeds_0 = array_ops.gather(seeds, 0)
      seeds_i = array_ops.gather(seeds, i)
      return (stateless_random_ops.stateless_categorical(
          logits=logits_i, num_samples=3, seed=seeds_i),
              stateless_random_ops.stateless_categorical(
                  logits=logits_i, num_samples=3, seed=seeds_0),
              stateless_random_ops.stateless_categorical(
                  logits=logits_0, num_samples=3, seed=seeds_i),
              stateless_random_ops.stateless_categorical(
                  logits=logits_0, num_samples=3, seed=seeds_0))

    self._test_loop_fn(loop_fn, 2)


class LoggingTest(PForTestCase):

  def test_print(self):
    x = random_ops.random_uniform([3, 5])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      return logging_ops.Print(
          x1, [x1, "x1", array_ops.shape(x1)], summarize=10)

    self._test_loop_fn(loop_fn, 3)

  def test_print_v2(self):
    x = constant_op.constant([1, 2, 3])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      with ops.control_dependencies([
          logging_ops.print_v2(
              x1, "x1", array_ops.shape(x1), summarize=10)]):
        return array_ops.identity(x1)

    self._test_loop_fn(loop_fn, 3)

    with self.captureWritesToStream(sys.stderr) as printed:
      self.evaluate(pfor_control_flow_ops.pfor(loop_fn, 3))
    self.assertIn("[1 2 3] x1 []", printed.contents())

  def test_assert(self):

    def loop_fn(i):
      return control_flow_ops.Assert(i < 10, [i, [10], [i + 1]])

    # TODO(agarwal): make this work with for_loop.
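    # Assert has no outputs to compare, so just run the vectorized ops.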
    with session.Session() as sess:
      sess.run(pfor_control_flow_ops.pfor(loop_fn, 3))
      sess.run(pfor_control_flow_ops.pfor(
          lambda i, pfor_config: loop_fn(i), 3))


class TensorArrayTest(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.disable_control_flow_v2()
    super(TensorArrayTest, self).setUp()

  def tearDown(self):
    if self._enabled:
      control_flow_v2_toggles.enable_control_flow_v2()
    super(TensorArrayTest, self).tearDown()

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_read(self):

    ta = tensor_array_ops.TensorArray(
        dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1)

    def loop_fn(i):
      return ta.read(i), ta.read(0)

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_gather(self):

    ta = tensor_array_ops.TensorArray(
        dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1)

    def loop_fn(i):
      return ta.gather([i]), ta.gather([0, 1])

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_write_and_scatter(self):

    t = tensor_array_ops.TensorArray(dtypes.int32, 10, clear_after_read=False)
    handle = t.handle

    def loop_fn(i):
      ta = t.write(i + 2, 2 * i).write(i, 5)
      ta = ta.scatter([4 + i], [4]).scatter([6 + i, 8 + i], [6 + i, 8 + i])
      return ta.flow

    t1 = pfor_control_flow_ops.pfor(loop_fn, iters=2)
    out1 = tensor_array_ops.TensorArray(
        dtypes.int32, handle=handle, flow=t1[-1]).stack()
    output1 = self._run_targets(out1)

    t2 = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, iters=2)
    out2 = tensor_array_ops.TensorArray(
        dtypes.int32, handle=handle, flow=t2[-1]).stack()
    output2 = self._run_targets(out2)
    self.assertAllClose(output2, output1)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_write(self):

    def loop_fn(i):
      # TODO(agarwal): switching the order of writes to ta1 does not work.
      ta1 = tensor_array_ops.TensorArray(dtypes.int32,
                                         2).write(0, i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(dtypes.int32, 1).write(0, 1)
      return ta1.stack(), ta2.stack()

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_scatter(self):

    def loop_fn(i):
      # TODO(agarwal): switching the order of scatter to ta1 does not work.
      ta1 = tensor_array_ops.TensorArray(dtypes.int32,
                                         2).scatter([0],
                                                    [[i, 2]]).scatter([1],
                                                                      [[1, 2]])
      ta2 = tensor_array_ops.TensorArray(dtypes.int32,
                                         2).scatter([0], [3]).scatter([1], [4])
      return ta1.stack(), ta2.stack()

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_read(self):

    def loop_fn(i):
      ta1 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2)
      # TODO(agarwal): ta1.read(i) currently is not supported.
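      # ta2 is loop invariant, so it can also be read at the loop dependent
      # index i.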
      return ta1.read(0), ta2.read(0), ta2.read(i)

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_gather(self):

    def loop_fn(i):
      ta1 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2)
      # TODO(agarwal): ta1.read(i) currently is not supported.
      return ta1.gather([0, 1]), ta2.gather([0, 1]), ta2.gather([i])

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_grad(self):
    x = random_ops.random_uniform([3, 2])
    ta = tensor_array_ops.TensorArray(
        dtypes.float32, 3, clear_after_read=False).unstack(x)
    y = math_ops.square(ta.stack())

    def loop_fn(i):
      y_i = array_ops.gather(y, i)
      grad = gradient_ops.gradients(y_i, x)[0]
      return array_ops.gather(grad, i)

    t1 = pfor_control_flow_ops.pfor(loop_fn, iters=3)
    # y = x * x. Hence dy/dx = 2 * x.
    actual_grad = 2.0 * x
    with session.Session() as sess:
      actual_grad, computed_grad = sess.run([t1, actual_grad])
      self.assertAllClose(actual_grad, computed_grad)


@test_util.run_all_in_graph_and_eager_modes
class TensorListTest(PForTestCase):

  def test_create_outside_and_write(self):
    handle1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    handle2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)

    def loop_fn(i):
      h1 = list_ops.tensor_list_set_item(handle1, 0, i)
      h1 = list_ops.tensor_list_set_item(h1, 1, 1)
      h2 = list_ops.tensor_list_set_item(handle2, 0, 1)
      return (list_ops.tensor_list_stack(h1, dtypes.int32),
              list_ops.tensor_list_stack(h2, dtypes.int32))

    self._test_loop_fn(loop_fn, 3)

  def _make_graph_def(self, text):
    ret = graph_pb2.GraphDef()
    text_format.Parse(text, ret)
    return ret

  def test_no_fallback_with_internal_stacking(self):

    # Create an op (really a function) that pfor definitely does not have a
    # converter for. Assumes pfor does not start looking up function definitions
    # for op-type-is-function-name calls.
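    # The function below is only reachable through an imported GraphDef node,
    # so pfor sees an op type with no registered converter.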
1147 @def_function.function 1148 def opaque_list_fetch(x): 1149 array_ops.identity(x) 1150 return list_ops.tensor_list_get_item(x, 0, dtypes.int32) 1151 1152 external_handle = list_ops.tensor_list_reserve([], 2, dtypes.int32) 1153 opaque_list_fetch_concrete = opaque_list_fetch.get_concrete_function( 1154 external_handle) 1155 opaque_list_fetch_name = opaque_list_fetch_concrete.name 1156 1157 def loop_fn(i): 1158 h1 = list_ops.tensor_list_reserve([], 2, dtypes.int32) 1159 h1 = list_ops.tensor_list_set_item(h1, 0, i) 1160 opaque_list_fetch_concrete.add_to_graph() 1161 graph_def = self._make_graph_def(""" 1162 node { name: 'x' op: 'Placeholder' 1163 attr { key: 'dtype' value { type: DT_FLOAT } }} 1164 node { name: 'fn' op: '""" + opaque_list_fetch_name.decode() 1165 + """' input: 'x:0' }""") 1166 return importer.import_graph_def( 1167 graph_def, 1168 input_map={"x:0": h1}, 1169 return_elements=["fn"], 1170 name="import")[0].outputs[0] 1171 1172 with self.assertRaisesRegex(ValueError, "No pfor vectorization"): 1173 self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False) 1174 with self.assertRaisesRegex(ValueError, "No pfor vectorization"): 1175 self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True) 1176 1177 def test_create_inside_and_write(self): 1178 1179 def loop_fn(i): 1180 h1 = list_ops.tensor_list_reserve([], 2, dtypes.int32) 1181 h1 = list_ops.tensor_list_set_item(h1, 0, i) 1182 h1 = list_ops.tensor_list_set_item(h1, 1, 1) 1183 h2 = list_ops.tensor_list_reserve([], 2, dtypes.int32) 1184 h2 = list_ops.tensor_list_set_item(h2, 0, 1) 1185 return (list_ops.tensor_list_stack(h1, dtypes.int32), 1186 list_ops.tensor_list_stack(h2, dtypes.int32)) 1187 1188 self._test_loop_fn(loop_fn, 3) 1189 1190 def test_create_outside_and_read(self): 1191 handle = list_ops.tensor_list_reserve([], 2, dtypes.int32) 1192 handle = list_ops.tensor_list_set_item(handle, 0, 0) 1193 handle = list_ops.tensor_list_set_item(handle, 1, 1) 1194 1195 def loop_fn(i): 1196 return (list_ops.tensor_list_get_item(handle, i, dtypes.int32), 1197 list_ops.tensor_list_get_item(handle, 0, dtypes.int32), 1198 list_ops.tensor_list_length(handle), 1199 list_ops.tensor_list_element_shape(handle, dtypes.int32), 1200 list_ops.tensor_list_element_shape(handle, dtypes.int64)) 1201 1202 self._test_loop_fn(loop_fn, 2) 1203 1204 def test_create_inside_and_read(self): 1205 1206 def loop_fn(i): 1207 handle = list_ops.tensor_list_reserve([], 2, dtypes.int32) 1208 handle = list_ops.tensor_list_set_item(handle, 0, i) 1209 handle = list_ops.tensor_list_set_item(handle, 1, 1) 1210 return (list_ops.tensor_list_get_item(handle, 0, dtypes.int32), 1211 list_ops.tensor_list_get_item(handle, i, dtypes.int32), 1212 list_ops.tensor_list_length(handle), 1213 list_ops.tensor_list_element_shape(handle, dtypes.int32), 1214 list_ops.tensor_list_element_shape(handle, dtypes.int64)) 1215 1216 self._test_loop_fn(loop_fn, 2) 1217 1218 def test_create_outside_and_push_back(self): 1219 h = list_ops.tensor_list_reserve([2], 2, dtypes.int32) 1220 1221 def loop_fn(i): 1222 handle = list_ops.tensor_list_push_back(h, [i, 2]) 1223 handle = list_ops.tensor_list_push_back(handle, [1, 2]) 1224 handle = list_ops.tensor_list_push_back(handle, [1, 2]) 1225 return list_ops.tensor_list_stack(handle, dtypes.int32) 1226 1227 self._test_loop_fn(loop_fn, 3) 1228 1229 def test_create_inside_and_push_back(self): 1230 1231 def loop_fn(i): 1232 handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32) 1233 handle = list_ops.tensor_list_push_back(handle, [i, 2]) 1234 handle = 
list_ops.tensor_list_push_back(handle, [1, 2]) 1235 return list_ops.tensor_list_stack(handle, dtypes.int32) 1236 1237 self._test_loop_fn(loop_fn, 3) 1238 1239 def test_pop_back_no_shape(self): 1240 1241 def loop_fn(i): 1242 handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32) 1243 handle = list_ops.tensor_list_push_back(handle, [1, 2]) 1244 handle = list_ops.tensor_list_push_back(handle, [i, 2]) 1245 handle, tensor = list_ops.tensor_list_pop_back(handle, dtypes.int32) 1246 return tensor, list_ops.tensor_list_stack(handle, dtypes.int32) 1247 1248 self._test_loop_fn(loop_fn, 3) 1249 1250 def test_pop_back_no_shape_capture(self): 1251 h = list_ops.tensor_list_reserve([2], 1, dtypes.int32) 1252 h = list_ops.tensor_list_push_back(h, [1, 2]) 1253 1254 def loop_fn(i): 1255 handle, tensor = list_ops.tensor_list_pop_back(h, dtypes.int32) 1256 handle = list_ops.tensor_list_push_back(handle, [1, i]) 1257 return tensor, list_ops.tensor_list_stack(handle, dtypes.int32) 1258 1259 self._test_loop_fn(loop_fn, 3) 1260 1261 def test_pop_back_with_shape(self): 1262 1263 @def_function.function 1264 def loop_fn(i): 1265 with backprop.GradientTape() as tape: 1266 handle = list_ops.tensor_list_reserve(None, 1, dtypes.float32) 1267 x = math_ops.cast(i, dtypes.float32)[None] 1268 tape.watch(x) 1269 handle = list_ops.tensor_list_push_back(handle, x) 1270 stacked = list_ops.tensor_list_stack(handle, dtypes.float32) 1271 list_grad = tape.gradient(stacked, x, x) 1272 self.assertEqual("TensorListPopBack", list_grad.op.type) 1273 return list_grad, stacked, list_grad.op.inputs[1] 1274 1275 self._test_loop_fn(loop_fn, 3) 1276 1277 def test_create_outside_and_scatter(self): 1278 h = list_ops.tensor_list_reserve([2], 2, dtypes.int32) 1279 1280 def loop_fn(i): 1281 handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=h) 1282 handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle) 1283 handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle) 1284 return list_ops.tensor_list_stack(handle, dtypes.int32) 1285 1286 self._test_loop_fn(loop_fn, 3) 1287 1288 def test_create_inside_and_scatter(self): 1289 1290 def loop_fn(i): 1291 handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32) 1292 handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle) 1293 handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle) 1294 return list_ops.tensor_list_stack(handle, dtypes.int32) 1295 1296 self._test_loop_fn(loop_fn, 3) 1297 1298 def test_loop_variant_scatter_indices(self): 1299 1300 def loop_fn(i): 1301 handle = list_ops.tensor_list_reserve([2], 10, dtypes.int32) 1302 handle = list_ops.tensor_list_scatter( 1303 [[1, i], [i + 1, 2]], 1304 [i, i + 5], input_handle=handle) 1305 return list_ops.tensor_list_stack(handle, dtypes.int32) 1306 1307 self._test_loop_fn(loop_fn, 5) 1308 1309 def test_loop_variant_scatter_duplicate_indices(self): 1310 if test_util.is_gpu_available(): 1311 self.skipTest( 1312 "Flaky in some GPU configurations due to TensorScatterNdUpdate " 1313 "nondeterminism.") 1314 1315 def loop_fn(i): 1316 handle = list_ops.tensor_list_reserve([2], 10, dtypes.int32) 1317 handle = list_ops.tensor_list_scatter( 1318 [[1, i], [1, i + 1], [i + 2, 3]], 1319 [i, i, i + 2], input_handle=handle) 1320 return list_ops.tensor_list_stack(handle, dtypes.int32) 1321 1322 self._test_loop_fn(loop_fn, 5) 1323 1324 def test_create_outside_and_gather(self): 1325 handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32) 1326 handle = 
list_ops.tensor_list_scatter([[2, 3]], [0], input_handle=handle) 1327 handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle) 1328 1329 def loop_fn(i): 1330 return (list_ops.tensor_list_gather(handle, [0, 1], dtypes.int32), 1331 list_ops.tensor_list_gather(handle, [i], dtypes.int32)) 1332 1333 self._test_loop_fn(loop_fn, 2) 1334 1335 def test_create_inside_and_gather(self): 1336 1337 def loop_fn(i): 1338 handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32) 1339 handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle) 1340 handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle) 1341 return (list_ops.tensor_list_gather(handle, [0, 1], dtypes.int32), 1342 list_ops.tensor_list_gather(handle, [i], dtypes.int32)) 1343 1344 self._test_loop_fn(loop_fn, 2) 1345 1346 def test_create_inside_and_concat(self): 1347 1348 def loop_fn(i): 1349 handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32) 1350 handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle) 1351 handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle) 1352 return gen_list_ops.tensor_list_concat_v2( 1353 handle, 1354 element_dtype=dtypes.int32, 1355 element_shape=[2], 1356 leading_dims=[]) 1357 1358 output = pfor_control_flow_ops.pfor(loop_fn, 2) 1359 self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], output[0]) 1360 self.assertAllClose([[2, 2], [2, 2]], output[1]) 1361 1362 def test_create_outside_and_concat(self): 1363 h = list_ops.tensor_list_reserve([2], 2, dtypes.int32) 1364 1365 def loop_fn(i): 1366 handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=h) 1367 handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle) 1368 return gen_list_ops.tensor_list_concat_v2( 1369 handle, 1370 element_dtype=dtypes.int32, 1371 element_shape=[2], 1372 leading_dims=[]) 1373 1374 output = pfor_control_flow_ops.pfor(loop_fn, 2) 1375 self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], output[0]) 1376 self.assertAllClose([[2, 2], [2, 2]], output[1]) 1377 1378 def test_tensor_list_from_tensor(self): 1379 t = random_ops.random_uniform([2, 3, 4]) 1380 1381 def loop_fn(i): 1382 handle = list_ops.tensor_list_from_tensor(array_ops.gather(t, i), [4]) 1383 return list_ops.tensor_list_stack(handle, t.dtype) 1384 1385 self._test_loop_fn(loop_fn, 2) 1386 1387 @test_util.enable_control_flow_v2 1388 def test_tensor_list_reserve_while_loop(self): 1389 # Here a loop invariant TensorList is captured by a while_loop, which then 1390 # performs loop dependent operations on it, resulting in a loop variant 1391 # output. This forces stacking of the variant handle captured by the 1392 # while_loop. 1393 # We handle this particular case by forcing vectorization of 1394 # TensorListReserve operation. 
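    # (i.e. the TensorListReserve below produces a stacked list handle up
    # front.)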
1395 1396 def loop_fn(i): 1397 handle = list_ops.tensor_list_reserve([], 2, dtypes.int32) 1398 _, out_handle = control_flow_ops.while_loop( 1399 lambda j, _: j < 2, lambda j, h: 1400 (j + 1, list_ops.tensor_list_set_item(h, j, i)), (0, handle)) 1401 return list_ops.tensor_list_stack(out_handle, dtypes.int32) 1402 1403 self._test_loop_fn(loop_fn, 2) 1404 1405 @test_util.enable_control_flow_v2 1406 def test_tensor_list_while_loop_stacked_cond_stacked_list(self): 1407 1408 def loop_fn(i): 1409 handle = list_ops.tensor_list_from_tensor([20, 21, 22, 23, i], []) 1410 _, out_handle = control_flow_ops.while_loop( 1411 lambda j, _: j < i, 1412 lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)), 1413 (0, handle)) 1414 return list_ops.tensor_list_stack(out_handle, dtypes.int32) 1415 1416 self._test_loop_fn(loop_fn, 5) 1417 1418 @test_util.enable_control_flow_v2 1419 def test_tensor_list_while_loop_stacked_cond_stacked_list_unknown_shape(self): 1420 1421 def loop_fn(i): 1422 handle = list_ops.tensor_list_reserve(None, 5, dtypes.int32) 1423 _, handle = control_flow_ops.while_loop( 1424 lambda j, _: j < 5, 1425 lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, 0)), 1426 (0, handle)) 1427 _, out_handle = control_flow_ops.while_loop( 1428 lambda j, _: j < i, 1429 lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)), 1430 (0, handle)) 1431 return list_ops.tensor_list_stack(out_handle, dtypes.int32) 1432 1433 self._test_loop_fn(loop_fn, 5) 1434 1435 @test_util.enable_control_flow_v2 1436 def test_tensor_list_while_loop_stacked_cond_unstacked_list(self): 1437 1438 def loop_fn(i): 1439 handle = list_ops.tensor_list_from_tensor([20, 21, 22, 23, 24], []) 1440 _, out_handle = control_flow_ops.while_loop( 1441 lambda j, _: j < i, 1442 lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)), 1443 (0, handle)) 1444 return list_ops.tensor_list_stack(out_handle, dtypes.int32) 1445 1446 self._test_loop_fn(loop_fn, 5) 1447 1448 def test_tensor_list_addn_already_stacked(self): 1449 1450 def loop_fn(i): 1451 l1 = list_ops.tensor_list_reserve([], 2, dtypes.int32) 1452 l1 = list_ops.tensor_list_set_item(l1, 0, i) 1453 l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32) 1454 l2 = list_ops.tensor_list_set_item(l2, 1, i) 1455 return list_ops.tensor_list_stack(math_ops.add_n([l1, l2]), dtypes.int32) 1456 1457 self._test_loop_fn(loop_fn, 2) 1458 1459 def test_tensor_list_addn_stacking_required(self): 1460 l1 = list_ops.tensor_list_reserve([], 2, dtypes.int32) 1461 l1 = list_ops.tensor_list_set_item(l1, 1, 1) 1462 1463 def loop_fn(i): 1464 l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32) 1465 l2 = list_ops.tensor_list_set_item(l2, 1, i) 1466 return list_ops.tensor_list_stack( 1467 math_ops.add_n([l1, l2]), dtypes.int32) 1468 1469 self._test_loop_fn(loop_fn, 2) 1470 1471 1472class OptionalTest(PForTestCase): 1473 1474 def test_optional_from_value(self): 1475 1476 def loop_fn(i): 1477 o = gen_dataset_ops.optional_from_value( 1478 [i, i + 1, constant_op.constant(3)]) 1479 gen_dataset_ops.optional_none() 1480 return gen_dataset_ops.optional_get_value( 1481 o, [dtypes.int32, dtypes.int32, dtypes.int32], 1482 [[], [], []]) 1483 1484 self._test_loop_fn(loop_fn, 2) 1485 1486 1487class StackTest(PForTestCase): 1488 1489 @test_util.run_v1_only("b/122612051") 1490 def test_stack_inside_loop_invariant(self): 1491 1492 def loop_fn(_): 1493 s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32) 1494 op1 = data_flow_ops.stack_push_v2(s, 1) 1495 with ops.control_dependencies([op1]): 1496 op2 = 
data_flow_ops.stack_push_v2(s, 2) 1497 with ops.control_dependencies([op2]): 1498 e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) 1499 with ops.control_dependencies([e2]): 1500 e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) 1501 return e1, e2 1502 1503 self._test_loop_fn(loop_fn, 2) 1504 1505 @test_util.run_v1_only("b/122612051") 1506 def test_stack_inside_push_loop_dependent(self): 1507 1508 def loop_fn(i): 1509 s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32) 1510 op1 = data_flow_ops.stack_push_v2(s, i) 1511 with ops.control_dependencies([op1]): 1512 op2 = data_flow_ops.stack_push_v2(s, 2) 1513 with ops.control_dependencies([op2]): 1514 e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) 1515 with ops.control_dependencies([e2]): 1516 e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) 1517 return e1, e2 1518 1519 self._test_loop_fn(loop_fn, 2) 1520 1521 @test_util.run_v1_only("b/122612051") 1522 def test_stack_outside_pop(self): 1523 s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32) 1524 op = data_flow_ops.stack_push_v2(s, 5) 1525 with ops.control_dependencies([op]): 1526 op = data_flow_ops.stack_push_v2(s, 6) 1527 with ops.control_dependencies([op]): 1528 op = data_flow_ops.stack_push_v2(s, 7) 1529 1530 def loop_fn(_): 1531 e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) 1532 with ops.control_dependencies([e1]): 1533 e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) 1534 return e1, e2 1535 1536 with ops.control_dependencies([op]): 1537 e1, e2 = pfor_control_flow_ops.pfor(loop_fn, iters=2) 1538 with ops.control_dependencies([e1, e2]): 1539 e3 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) 1540 v1, v2, v3 = self._run_targets([e1, e2, e3], run_init=False) 1541 self.assertAllEqual([7, 7], v1) 1542 self.assertAllEqual([6, 6], v2) 1543 self.assertAllEqual(5, v3) 1544 1545 @test_util.run_v1_only("b/122612051") 1546 def test_stack_outside_push(self): 1547 s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32) 1548 1549 def loop_fn(_): 1550 return data_flow_ops.stack_push_v2(s, 7) 1551 1552 with self.assertRaisesRegex(ValueError, "StackPushV2 not allowed.*"): 1553 pfor_control_flow_ops.pfor(loop_fn, iters=2) 1554 1555 1556# TODO(agarwal): test nested while_loops. This currently requires converting a 1557# tf.cond. 
class WhileV1Test(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.disable_control_flow_v2()
    super(WhileV1Test, self).setUp()

  def tearDown(self):
    if self._enabled:
      control_flow_v2_toggles.enable_control_flow_v2()
    super(WhileV1Test, self).tearDown()

  def test_while_outside_loop(self):

    x = control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    def loop_fn(i):
      return x + i

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_invariant_while(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_invariant_while_with_control_dependency(self):

    def loop_fn(i):
      with ops.control_dependencies([i]):
        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
                                           [0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while_with_stateful_ops(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4, lambda j, x:
          (j + 1, x + random_ops.random_uniform([])), [0, 0.])[0]

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
                                         (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while(self):
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      _, total = control_flow_ops.while_loop(
          lambda j, _: j < lengths_i, lambda j, t:
          (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
      return total

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while_jacobian(self):
    x = random_ops.random_uniform([1, 3])
    y = random_ops.random_uniform([3, 3])

    # out = x @ y @ y @ y @ y, where @ is matmul operator.
    _, out = control_flow_ops.while_loop(
        lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
        [0, x])

    def loop_fn(i):
      out_i = array_ops.gather(out, i, axis=1)
      return array_ops.reshape(gradient_ops.gradients(out_i, x)[0], [-1])

    out = pfor_control_flow_ops.pfor(loop_fn, iters=3)

    # The above code does not work with tf.while_loop instead of pfor. So we
    # manually compute the expected output here.
    # Note that gradient of output w.r.t. x is (y @ y @ y @ y)^T.
    expected_output = y
    for _ in range(3):
      expected_output = math_ops.matmul(expected_output, y)
    expected_output = array_ops.transpose(expected_output, [1, 0])

    with session.Session() as sess:
      out, expected = sess.run([out, expected_output])
      self.assertAllClose(expected, out)

  @test_util.run_v1_only("b/122612051")
  def test_tensor_array_as_loop_variable(self):

    def loop_fn(i):

      def body(j, ta):
        ta = ta.write(j, i + j * j)
        return j + 1, ta

      _, ta = control_flow_ops.while_loop(
          lambda j, _: j < 4, body,
          (0, tensor_array_ops.TensorArray(dtypes.int32, size=4)))
      return ta.stack()

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_read_tensor_array_partitioned_indices(self):
    # Note that tensor array values are pfor loop dependent, and the while loop
    # termination condition is also dependent on pfor iteration.
    def loop_fn(i):
      ta = tensor_array_ops.TensorArray(dtypes.int32, size=6)
      ta = ta.unstack(i + list(range(5)))

      def body(j, s):
        return j + 1, s + ta.read(j)

      _, s = control_flow_ops.while_loop(lambda j, _: j < i, body, (0, 0))
      return s

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_external_while_loop_grad(self):
    # Here we test that external while_loops that are extended from inside pfor
    # (due to gradient calls) are not actually converted. If the below was
    # converted, all pfor iterations would write to the same tensor array
    # indices.
    x = constant_op.constant(1.)

    def body(j, ta):
      ta = ta.write(j, x)
      return j + 1, ta

    _, ta = control_flow_ops.while_loop(
        lambda j, _: j < 4, body,
        (0, tensor_array_ops.TensorArray(dtypes.float32, size=4)))
    out = ta.stack()

    def loop_fn(i):
      out_i = array_ops.gather(out, i)
      return gradient_ops.gradients(out_i, x)[0]

    with session.Session() as sess:
      # out is [x, x, x]. Hence the gradients should be [1, 1, 1].
      self.assertAllEqual([1, 1, 1],
                          sess.run(pfor_control_flow_ops.pfor(loop_fn, 3)))

  @test_util.run_v1_only("b/122612051")
  def test_tensor_array_grad(self):
    inp = constant_op.constant(np.random.rand(3, 4, 2), dtype=dtypes.float32)
    ta = tensor_array_ops.TensorArray(dtypes.float32, size=3)
    ta = ta.unstack(inp)

    def loop_fn(i):

      def body(j, x):
        value = ta.gather([j])
        value = array_ops.gather(array_ops.reshape(value, [4, 2]), i)
        return j + 1, x + value

      _, out = control_flow_ops.while_loop(lambda j, _: j < 3, body,
                                           (0, array_ops.zeros([2])))
      out = math_ops.reduce_prod(out)
      return out, gradient_ops.gradients(out, inp)[0]

    pfor_out, pfor_out_grad = pfor_control_flow_ops.pfor(loop_fn, 4)
    # Note that tf.while_loop does not work in the setup above. So we manually
    # construct the equivalent computation of the above loops here.
    real_out = math_ops.reduce_sum(inp, axis=[0])
    real_out = math_ops.reduce_prod(real_out, axis=[1])
    # Note that gradients of real_out will accumulate the gradients across the
    # output value. Hence we do the same aggregation on pfor_out_grad.
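    # i.e. sum pfor_out_grad across the pfor dimension before comparing.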
    real_out_grad = gradient_ops.gradients(real_out, inp)[0]
    sum_pfor_out_grad = math_ops.reduce_sum(pfor_out_grad, axis=[0])

    with session.Session() as sess:
      v1, v2, v1_grad, v2_grad = sess.run(
          [pfor_out, real_out, sum_pfor_out_grad, real_out_grad])
      self.assertAllClose(v1, v2)
      self.assertAllClose(v1_grad, v2_grad)


def dynamic_lstm_input_fn(batch_size, state_size, max_steps):
  # We make inputs and sequence_length constant so that multiple session.run
  # calls produce the same result.
  inputs = constant_op.constant(
      np.random.rand(batch_size, max_steps, state_size), dtype=dtypes.float32)
  sequence_length = np.random.randint(0, size=[batch_size], high=max_steps + 1)
  sequence_length = constant_op.constant(sequence_length, dtype=dtypes.int32)
  return inputs, sequence_length


def create_dynamic_lstm(cell_fn, batch_size, state_size, max_steps):
  cell = cell_fn(state_size)
  inputs, sequence_length = dynamic_lstm_input_fn(batch_size, state_size,
                                                  max_steps)
  inputs_ta = tensor_array_ops.TensorArray(
      dtypes.float32, size=max_steps, element_shape=[batch_size, state_size])
  inputs_time_major = array_ops.transpose(inputs, [1, 0, 2])
  inputs_ta = inputs_ta.unstack(inputs_time_major)
  zeros = array_ops.zeros([state_size])

  def loop_fn(i):
    sequence_length_i = array_ops.gather(sequence_length, i)

    def body_fn(t, state, ta):
      inputs_t = array_ops.expand_dims(
          array_ops.gather(inputs_ta.read(t), i), 0)
      output, new_state = cell(inputs_t, state)
      output = array_ops.reshape(output, [-1])
      # TODO(agarwal): one optimization that dynamic_rnn uses is to avoid the
      # array_ops.where when t < min(sequence_length). Doing that requires
      # supporting tf.cond pfor conversion.
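      # Mask out timesteps at or beyond this example's sequence length: emit
      # zeros to the TensorArray and carry the previous state through unchanged.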
      done = t >= sequence_length_i
      output = array_ops.where(done, zeros, output)
      ta = ta.write(t, output)
      new_state = [
          array_ops.where(done, s, ns)
          for s, ns in zip(nest.flatten(state), nest.flatten(new_state))
      ]
      new_state = nest.pack_sequence_as(state, new_state)
      return t + 1, new_state, ta

    def condition_fn(t, _, unused):
      del unused
      return t < max_steps

    initial_state = cell.zero_state(1, dtypes.float32)
    _, state, ta = control_flow_ops.while_loop(condition_fn, body_fn, [
        0, initial_state,
        tensor_array_ops.TensorArray(dtypes.float32, max_steps)
    ])

    new_state = [array_ops.reshape(x, [-1]) for x in nest.flatten(state)]
    new_state = nest.pack_sequence_as(initial_state, new_state)
    return ta.stack(), new_state

  pfor_output = pfor_control_flow_ops.pfor(loop_fn, batch_size)
  tf_output = rnn.dynamic_rnn(
      cell,
      inputs,
      sequence_length=sequence_length,
      initial_state=cell.zero_state(batch_size, dtypes.float32))
  return pfor_output, tf_output


@test_util.run_all_in_graph_and_eager_modes
class WhileV2Test(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()
    super(WhileV2Test, self).setUp()

  def tearDown(self):
    if not self._enabled:
      control_flow_v2_toggles.disable_control_flow_v2()
    super(WhileV2Test, self).tearDown()

  def test_while_outside_loop(self):

    def _f():
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    def loop_fn(i):
      return _f() + i

    self._test_loop_fn(loop_fn, 3)

  def test_invariant_while(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    self._test_loop_fn(loop_fn, 3)

  def test_invariant_while_with_control_dependency(self):

    def loop_fn(i):
      with ops.control_dependencies([i]):
        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
                                           [0])

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_stateful_ops(self):

    def loop_fn(_):
      j, _ = control_flow_ops.while_loop(
          lambda j, x: j < 4, lambda j, x:
          (j + 1, x + random_ops.random_uniform([])), [0, 0.])
      return j

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_variable(self):
    v = resource_variable_ops.ResourceVariable(5.)
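    # v is loop invariant across pfor iterations; each iteration runs the while
    # loop for 4 steps, so every output should equal 4 * v.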

    def loop_fn(_):
      _, output = control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
                                              (j + 1, x + v), [0, 0.])
      return output

    self._test_loop_fn(loop_fn, 3)

  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
                                         (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3)

  def test_while(self):
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      return control_flow_ops.while_loop(
          lambda j, _: j < lengths_i, lambda j, t:
          (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])

    self._test_loop_fn(loop_fn, 3)

  def test_while_change_input_invariance(self):
    # This tests cases where a loop invariant input to while has loop dependent
    # operations applied to it inside the while body.
    # It also tests inputs that are passed through.
    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, *_: j < i, lambda j, x, y, z, w:
          (j + 1, x + i, y + x, z, w), [
              0,
              constant_op.constant(0),
              constant_op.constant(1), i,
              constant_op.constant(2)
          ])

    self._test_loop_fn(loop_fn, 3)

  def test_while_shape_invariants(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, *_: j < 4,
          lambda j, x, y: (j + 1, x + i, y + 1),
          [0, constant_op.constant([0, 1]),
           constant_op.constant([2, 3])],
          shape_invariants=[
              None,
              tensor_shape.TensorShape([2]),
              tensor_shape.TensorShape([2])
          ])

    self._test_loop_fn(loop_fn, 3)

  def test_while_jacobian(self):
    # Note that we wrap the code below in a tf.function since we don't want the
    # while_loop call to be evaluated eagerly using a python loop.
    @def_function.function
    def _f(x, y, use_pfor):
      # out = x @ y @ y @ y @ y, where @ is matmul operator.
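      # The Jacobian below is computed two ways: vectorized with pfor and
      # unrolled with for_loop; the assertion at the end checks that they agree.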
      _, out = control_flow_ops.while_loop(
          lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
          [0, x])

      def loop_fn(i):
        out_i = array_ops.gather(out, i, axis=1)
        grad = gradient_ops.gradients(out_i, x)
        return array_ops.reshape(grad[0], [-1])

      if use_pfor:
        return pfor_control_flow_ops.pfor(loop_fn, iters=3)
      else:
        return pfor_control_flow_ops.for_loop(
            loop_fn, iters=3, loop_fn_dtypes=out.dtype)

    x = constant_op.constant(np.random.uniform(size=(1, 3)))
    y = constant_op.constant(np.random.uniform(size=(3, 3)))
    self.assertAllClose(_f(x, y, True), _f(x, y, False))

  def test_scan(self):
    np.random.seed(seed=42)
    data = np.random.randn(3).astype(np.float32)

    def log_prob(x):
      return math_ops.reduce_sum(functional_ops.scan_v2(
          lambda _, yi: (x - yi)**2,
          elems=data,
          initializer=constant_op.constant(0.)))

    x = variables.Variable(array_ops.ones([2]))
    self.evaluate(x.initializer)
    v_log_prob = lambda x: pfor_control_flow_ops.vectorized_map(log_prob, x)
    theoretical, numerical = gradient_checker_v2.compute_gradient(
        v_log_prob, (x,), delta=1e-3)
    self.assertAllClose(theoretical, numerical, rtol=1e-2)

  def test_scan_captured_variable(self):
    if not context.executing_eagerly():
      self.skipTest("Test only written for 2.x")
    v = variables.Variable(math_ops.range(10, dtype=dtypes.float32))

    def loop_fn(idx):
      del idx
      return functional_ops.scan_v2(lambda _, i: array_ops.gather(v, i),
                                    elems=math_ops.range(v.shape[0]),
                                    initializer=0.0)
    with backprop.GradientTape() as tape:
      result = pfor_control_flow_ops.pfor(loop_fn, 2)
    self.assertAllClose([2.] * 10, tape.gradient(result, v))


@test_util.run_all_in_graph_and_eager_modes
class NestedControlFlowTest(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()
    super(NestedControlFlowTest, self).setUp()

  def tearDown(self):
    if not self._enabled:
      control_flow_v2_toggles.disable_control_flow_v2()
    super(NestedControlFlowTest, self).tearDown()

  def _cond(self, f=None, split=0):
    if f is None:
      f = lambda x, y: (x, y)

    def _f(x, y):
      return control_flow_ops.cond(y > split, lambda: f(x, y), lambda:
                                   (x + 1., y))

    return _f

  def _while(self, f=None):
    if f is None:
      f = lambda x, y: (x, y)

    def _f(x, y):
      return control_flow_ops.while_loop(
          lambda j, _: j < y, lambda j, t:
          (j + 1, t + array_ops.gather(f(x, y)[0], j)), [0, x])[1], y

    return _f

  def _test_helper(self, f):
    x = random_ops.random_uniform([5, 5])
    y = constant_op.constant([4, -1, 2, -2, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      y_i = array_ops.gather(y, i)
      return f(x_i, y_i)

    self._test_loop_fn(loop_fn, 5)

  def test_cond_while(self):
    self._test_helper(self._cond(self._while()))

  def test_while_cond(self):
    self._test_helper(self._while(self._cond()))

  def test_while_while(self):
    self._test_helper(self._while(self._while()))

  def test_cond_cond(self):
    self._test_helper(self._cond(self._cond()))


@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class StatelessIfTest(PForTestCase):

  def test_loop_variant_cond(self):
    x = [1, 2, 3, 4, 5.]
    y = 2.5

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      # Note that the output has a combination of then and else branches being
      # loop variant / invariant.
      return cond_v2.cond_v2(x_i < y, lambda: (y - x_i, y, 1., 2.), lambda:
                             (x_i - y, 0., y, 3.))

    self._test_loop_fn(loop_fn, iters=5)

  def test_loop_invariant_cond(self):
    x = [1, 2, 3, 4, 5.]
    y = 0.5
    z = random_ops.random_uniform([])

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      # Note that the output has a combination of then and else branches being
      # loop variant / invariant.
      return cond_v2.cond_v2(z < y, lambda: (y - x_i, y, 1., 2.), lambda:
                             (x_i - y, 0., y, 3.))

    self._test_loop_fn(loop_fn, iters=5)

  def test_empty_branch(self):
    x = [1, 2, 3, 4, 5.]
    y = 6.

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return cond_v2.cond_v2(
          x_i < y,  # Note that else branch is empty.
          lambda: (y - x_i, y, 1., 2.),
          lambda: (x_i - y, 0., y, 3.))

    self._test_loop_fn(loop_fn, iters=5)


@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class IfTest(PForTestCase):

  def test_read_var(self):
    self.skipTest("b/156438918")  # Flaky

    x = [1, 2, 3, 4, 5.]
    y = 2.5
    z = resource_variable_ops.ResourceVariable(5.)
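    # Both cond branches read the captured resource variable z, so the pfor
    # conversion needs to handle variable reads inside cond_v2 branches.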

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return cond_v2.cond_v2(x_i < y, lambda: z - x_i, lambda: z + x_i)

    self._test_loop_fn(loop_fn, iters=5)


class RNNTest(PForTestCase):

  @test_util.run_v1_only("b/122612051")
  def test_dynamic_rnn(self):
    pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, 3, 5,
                                                   7)
    self.run_and_assert_equal(pfor_outputs, tf_outputs)

  @test_util.run_v1_only("b/122612051")
  def test_dynamic_lstm(self):
    pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicLSTMCell, 3, 5,
                                                   7)
    self.run_and_assert_equal(pfor_outputs, tf_outputs)


# TODO(agarwal): benchmark numbers on GPU for graphs based on while_loop
# conversion don't look good. Some of it seems like a lot of copies between
# host and device. Optimize that.
class Benchmarks(test.Benchmark):

  def _run(self, targets, iters, name=None):

    def _done(t):
      # Note that we don't use tf.control_dependencies since that will not make
      # sure that the computation on GPU has actually finished. So we fetch the
      # first element of the output, and assume that this will not be called on
      # empty tensors.
      return array_ops.gather(array_ops.reshape(t, [-1]), 0)

    targets = [_done(x) for x in nest.flatten(targets)]
    sess = session.Session()
    with sess:
      init = variables.global_variables_initializer()
      sess.run(init)
      run_fn = sess.make_callable(targets)
      run_fn()  # Warm up
      begin = time.time()
      for _ in range(iters):
        run_fn()
      end = time.time()
      avg_time_ms = 1000 * (end - begin) / iters
      self.report_benchmark(iters=iters, wall_time=avg_time_ms, name=name)
      return avg_time_ms

  def benchmark_sess_run_overhead(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0)
      self._run(x, 10000, name="session_run_overhead")

  def benchmark_add(self):
    with ops.Graph().as_default():
      n = 256
      params = 1000
      x = random_ops.random_normal([n, params])
      y = random_ops.random_normal([n, params])

      def loop_fn(i):
        x_i = array_ops.gather(x, i)
        y_i = array_ops.gather(y, i)
        return x_i + y_i

      pfor_outputs = pfor_control_flow_ops.pfor(loop_fn, n)
      while_outputs = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, n)
      manual = x + y

      self._run(manual, 1000, name="manual_add")
      self._run(pfor_outputs, 1000, name="pfor_add")
      self._run(while_outputs, 100, name="while_add")

  def benchmark_matmul(self):
    with ops.Graph().as_default():
      n = 1024
      params = 1000
      x = random_ops.random_normal([n, params])
      y = random_ops.random_normal([params, params])

      def loop_fn(i):
        x_i = array_ops.expand_dims(array_ops.gather(x, i), 0)
        return math_ops.matmul(x_i, y)

      pfor_outputs = pfor_control_flow_ops.pfor(loop_fn, n)
      while_outputs = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, n)
      manual = math_ops.matmul(x, y)

      self._run(manual, 1000, name="manual_matmul")
      self._run(pfor_outputs, 1000, name="pfor_matmul")
      self._run(while_outputs, 100, name="while_matmul")

  def benchmark_map_fn(self):
    with ops.Graph().as_default():
      b = 256
      params = 1000
      inp = random_ops.random_normal((b, params))
      fn = lambda x: x * x

      def pfor_map_fn(f, x):
        return pfor_control_flow_ops.pfor(lambda i: f(array_ops.gather(x, i)),
                                          array_ops.shape(x)[0])

      map_output = map_fn.map_fn(fn, inp)
      pfor_output = pfor_map_fn(fn, inp)

      self._run(map_output, 100, name="tf_map_fn")
      self._run(pfor_output, 100, name="pfor_map_fn")

  def benchmark_basic_while(self):
    with ops.Graph().as_default():

      def loop_fn(i):
        _, s = control_flow_ops.while_loop(lambda t, x: t < i, lambda t, x:
                                           (t + 1, x + i), [0, 0])
        return s

      iters = 50
      pfor_output = pfor_control_flow_ops.pfor(loop_fn, iters)
      for_loop_output = pfor_control_flow_ops.for_loop(loop_fn, dtypes.int32,
                                                       iters)
      self._run(pfor_output, 100, name="pfor_basic")
      self._run(for_loop_output, 100, name="for_loop_basic")

  def benchmark_dynamic_rnn(self):
    with ops.Graph().as_default():
      pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, 128,
                                                     512, 16)
      self._run(pfor_outputs, 100, name="pfor_rnn")
      self._run(tf_outputs, 100, name="tf_rnn")

  def benchmark_reduction(self):
    n = 1024
    with ops.Graph().as_default():
      x = random_ops.random_uniform([n, n])
      w = random_ops.random_uniform([n, n])

      def loop_fn(i, pfor_config):
        x_i = array_ops.gather(x, i)
        return math_ops.reduce_sum(
            math_ops.matmul(pfor_config.reduce_concat(x_i), w))

      # Note that output_reduction will be tiled, so there may be some minor
      # overheads compared to output_no_reduction.
      output_reduction = pfor_control_flow_ops.pfor(loop_fn, n)
      output_no_reduction = math_ops.reduce_sum(math_ops.matmul(x, w))
      # Benchmark to test that reduction does not add overhead and its output
      # is treated as loop invariant.
      self._run(output_reduction, 30, name="matmul_reduction")
      self._run(output_no_reduction, 30, name="matmul_no_reduction")


class SparseTest(PForTestCase):

  @test_util.run_v1_only("b/122612051")
  def test_var_loop_len(self):
    num_iters = array_ops.placeholder(dtypes.int32)

    def loop_fn(_):
      return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6],
                                        [3])  # [0, 2, 0]

    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    with self.cached_session() as sess:
      sess.run(pfor, feed_dict={num_iters: 3})

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_none_stacked(self):
    num_iters = 10

    def loop_fn(_):
      return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6],
                                        [3])  # [0, 2, 0]

    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)

    indices = [[i, j] for i in range(num_iters) for j in range(3)]
    values = [4, 5, 6] * num_iters
    dense_shapes = [num_iters, 3]
    # Expected result: [[4, 5, 6], [4, 5, 6], [4, 5, 6], ...]
    manual = sparse_tensor.SparseTensor(indices, values, dense_shapes)
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_all_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      indices = array_ops.expand_dims(i, 0)
      return sparse_tensor.SparseTensor(indices, i, i + 1)  # [0, ..., 0, i]

    # Expected result: [[0], [0, 1], [0, 0, 2], [0, 0, 0, 3], ...]
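    # Iteration i produces a SparseTensor with the single value i at index [i]
    # and dense shape [i + 1]; stacking pads every row out to the largest dense
    # shape, giving a (num_iters x num_iters) matrix with i on the diagonal.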
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)],
                                        list(range(num_iters)),
                                        (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_indices_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      indices = array_ops.expand_dims(i, 0)
      return sparse_tensor.SparseTensor(indices, [1], [num_iters])

    # Expected result: identity matrix of size num_iters * num_iters
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)],
                                        [1] * num_iters, (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_values_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      return sparse_tensor.SparseTensor([[0]], i, [num_iters])  # [i, 0, ..., 0]

    # Expected result: [[0, 0, ...], [1, 0, ...], [2, 0, ...], ...]
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)],
                                        list(range(num_iters)),
                                        (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_shapes_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      return sparse_tensor.SparseTensor([[0]], [1], i + 1)  # [1, 0, ..., 0]

    # Expected result: [[1, 0, 0, ...], [1, 0, 0, ...], ...]
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)],
                                        [1] * num_iters, (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_shapes_stacked_2D(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i + 1, dtypes.int64), 0)
      shape = array_ops.concat([i, i], 0)
      return sparse_tensor.SparseTensor([[0, 0]], [1], shape)  # [1, 0, ..., 0]

    # Expected result: [[[1, 0, ...], [0, ..., 0], [0, ..., 0], ...], ...]
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, 0, 0] for i in range(num_iters)],
                                        [1] * num_iters,
                                        (num_iters, num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)


# Dummy CompositeTensor to test CompositeTensor support.
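# `Particle` carries a mass and a velocity component, and `ParticleSpec` below
# tells pfor how to batch and unbatch those components. As an illustration,
# vectorized_map over a batch of particles gathers each component along axis 0,
# e.g. mapping `lambda x: x.mass * x.velocity` over
# `Particle(mass=[1., 2.], velocity=[3., 4.])` evaluates to `[3., 8.]`.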
class Particle(composite_tensor.CompositeTensor):
  """A (batch of) particles each defined by a mass and a scalar velocity."""

  def __init__(self, mass, velocity):
    mass = ops.convert_to_tensor(mass)
    velocity = ops.convert_to_tensor(velocity)
    self.shape = array_ops.broadcast_static_shape(mass.shape, velocity.shape)
    self.mass = mass
    self.velocity = velocity

  @property
  def _type_spec(self):
    return ParticleSpec(
        type_spec.type_spec_from_value(self.mass),
        type_spec.type_spec_from_value(self.velocity))


class ParticleSpec(type_spec.BatchableTypeSpec):

  def __init__(self, mass, velocity):
    self.shape = array_ops.broadcast_static_shape(
        mass.shape, velocity.shape)
    self.mass = mass
    self.velocity = velocity

  def _serialize(self):
    return (self.mass, self.velocity)

  @property
  def value_type(self):
    return Particle

  @property
  def _component_specs(self):
    return (self.mass, self.velocity)

  def _to_components(self, value):
    return (value.mass, value.velocity)

  def _from_components(self, components):
    return Particle(*components)

  def _pad_shape_to_full_rank(self, s):
    """Pad component shapes with 1's so all components have the same rank."""
    return tensor_shape.TensorShape(
        [1] * (self.shape.ndims - s.ndims)).concatenate(s)

  def _batch(self, batch_size):
    return ParticleSpec(
        mass=tensor_spec.TensorSpec(
            dtype=self.mass.dtype,
            shape=tensor_shape.TensorShape([batch_size]).concatenate(
                self._pad_shape_to_full_rank(self.mass.shape))),
        velocity=tensor_spec.TensorSpec(
            dtype=self.velocity.dtype,
            shape=tensor_shape.TensorShape([batch_size]).concatenate(
                self._pad_shape_to_full_rank(self.velocity.shape))))

  def _unbatch(self):
    return ParticleSpec(
        tensor_spec.TensorSpec(dtype=self.mass.dtype,
                               shape=self.mass.shape[1:]),
        tensor_spec.TensorSpec(dtype=self.velocity.dtype,
                               shape=self.velocity.shape[1:]))

  def _to_tensor_list(self, value):
    return [array_ops.reshape(
        value.mass,
        self._pad_shape_to_full_rank(value.mass.shape)),
            array_ops.reshape(
                value.velocity,
                self._pad_shape_to_full_rank(value.velocity.shape))]


class CompositeTensorTest(PForTestCase, parameterized.TestCase):

  @parameterized.parameters((None,), (3,))
  def test_create_composite_inside_loop(self, parallel_iterations):
    num_particles = 10
    velocities = random_ops.random_uniform([num_particles])
    particles = pfor_control_flow_ops.pfor(
        # Build a batch of particles all with the same mass.
        lambda i: Particle(mass=4., velocity=array_ops.gather(velocities, i)),
        num_particles,
        parallel_iterations=parallel_iterations)
    particles_mass, particles_velocity, velocities = self.evaluate(
        (particles.mass, particles.velocity, velocities))
    self.assertAllEqual(particles_mass, 4. * np.ones([num_particles]))
    self.assertAllEqual(particles_velocity, velocities)

  @parameterized.parameters((None,), (3,))
  def test_composite_is_converted_to_batched_tensor(
      self, parallel_iterations):
    particles = pfor_control_flow_ops.pfor(
        lambda _: Particle(mass=random_ops.random_uniform([3]),  # pylint: disable=g-long-lambda
                           velocity=random_ops.random_uniform([5, 3])),
        4,
        parallel_iterations=parallel_iterations)
    # Naively batching the component shapes would give `[4, 3]` and `[4, 5, 3]`
    # which have no consistent broadcast shape.
    self.assertAllEqual(particles.mass.shape, [4, 1, 3])
    self.assertAllEqual(particles.velocity.shape, [4, 5, 3])

  def test_vectorized_map_gathers_composite_tensors(self):
    particles = Particle(mass=[1., 2., 3., 4., 5.],
                         velocity=[1., 2., 3., 4., 5.])
    self.assertAllEqual(
        pfor_control_flow_ops.vectorized_map(
            lambda x: x.mass * x.velocity, particles),
        particles.mass * particles.velocity)

  def test_vectorized_map_of_ragged_tensors(self):
    # Vmap should be able to handle ragged Tensors as long as they're not
    # *actually* ragged.
    ragged = ragged_tensor.RaggedTensor.from_uniform_row_length(
        ragged_tensor.RaggedTensor.from_row_lengths(
            values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
            row_lengths=[3, 3, 3, 3]),
        uniform_row_length=2)  # Overall shape [2, 2, 3].
    self.assertAllEqual(
        pfor_control_flow_ops.vectorized_map(
            lambda x: x.to_tensor(shape=[2, 3]), ragged),
        ragged.to_tensor(shape=[2, 2, 3]))


class ParsingTest(PForTestCase):

  def test_decode_csv(self):
    csv_tensor = constant_op.constant([["1:2:3"], ["::"], ["7:8:9"]])
    kwargs = {"record_defaults": [[10], [20], [30]], "field_delim": ":"}

    def loop_fn(i):
      line = array_ops.gather(csv_tensor, i)
      return parsing_ops.decode_csv(line, **kwargs)

    self._test_loop_fn(loop_fn, iters=3)

  @test_util.run_v1_only("b/122612051")
  def test_parse_single_example(self):

    def _int64_feature(*values):
      return feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=values))

    def _bytes_feature(*values):
      return feature_pb2.Feature(
          bytes_list=feature_pb2.BytesList(
              value=[v.encode("utf-8") for v in values]))

    examples = constant_op.constant([
        example_pb2.Example(
            features=feature_pb2.Features(
                feature={
                    "dense_int": _int64_feature(i),
                    "dense_str": _bytes_feature(str(i)),
                    "sparse_int": _int64_feature(i, i * 2, i * 4, i * 8),
                    "sparse_str": _bytes_feature(*["abc"] * i)
                })).SerializeToString() for i in range(10)
    ])

    features = {
        "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0),
        "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""),
        "sparse_int": parsing_ops.VarLenFeature(dtypes.int64),
        "sparse_str": parsing_ops.VarLenFeature(dtypes.string),
    }

    def loop_fn(i):
      example_proto = array_ops.gather(examples, i)
      f = parsing_ops.parse_single_example(example_proto, features)
      return f

    pfor = pfor_control_flow_ops.pfor(loop_fn, iters=10)
    manual = parsing_ops.parse_example(examples, features)
    self.run_and_assert_equal(pfor, manual)


class PartitionedCallTest(PForTestCase):

  def test_simple(self):

    @def_function.function
    def f(x):
      return math_ops.square(x) + 1

    z = random_ops.random_uniform([4])

    def loop_fn(i):
      return f(array_ops.gather(z, i))

    self._test_loop_fn(loop_fn, 4)

  def test_nested_calls(self):

    @def_function.function
    def inner(x):
      return math_ops.square(x)

    @def_function.function
    def outer(y):
      return math_ops.reduce_sum(inner(y)) + 2

    z = random_ops.random_uniform([4, 2])

    def loop_fn(i):
      return outer(array_ops.gather(z, i))

    self._test_loop_fn(loop_fn, 4)

  def test_nested_definition(self):

    @def_function.function
    def outer(y):

      @def_function.function
      def inner(x):
        return math_ops.square(x) + 1

      return math_ops.reduce_sum(inner(y)) + 2

    z = random_ops.random_uniform([4, 2])

    def loop_fn(i):
      return outer(array_ops.gather(z, i))

    self._test_loop_fn(loop_fn, 4)

  def test_gradients(self):

    @def_function.function
    def f(x):
      return math_ops.square(x) + 1

    z = random_ops.random_uniform([4, 2])

    def loop_fn(i):
      z_i = array_ops.gather(z, i)
      with backprop.GradientTape() as g:
        g.watch(z_i)
        out = f(z_i)
      return out, g.gradient(out, z_i)

    self._test_loop_fn(loop_fn, 4)

  def test_stateful_with_gradients(self):

    z = random_ops.random_uniform([4, 2])
    v = variables.Variable(z[0])

    @def_function.function
    def f(x):
      return math_ops.square(x) + v + 1

    def loop_fn(i):
      z_i = array_ops.gather(z, i)
      with backprop.GradientTape() as g:
        g.watch(z_i)
        out = f(z_i)
      return out, g.gradient(out, z_i)

    self._test_loop_fn(loop_fn, 4)


class SpectralTest(PForTestCase, parameterized.TestCase):

  @parameterized.parameters(
      (fft_ops.fft,),
      (fft_ops.fft2d,),
      (fft_ops.fft3d,),
      (fft_ops.ifft,),
      (fft_ops.ifft2d,),
      (fft_ops.ifft3d,),
  )
  def test_fft(self, op_func):
    shape = [2, 3, 4, 3, 4]
    x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return op_func(x_i)

    self._test_loop_fn(loop_fn, 2)

  @parameterized.parameters(
      (fft_ops.rfft,),
      (fft_ops.rfft2d,),
      (fft_ops.rfft3d,),
  )
  @test.disable_with_predicate(
      pred=test.is_built_with_rocm,
      skip_message="Disable subtest on ROCm due to rocfft issues")
  def test_rfft(self, op_func):
    for dtype in (dtypes.float32, dtypes.float64):
      x = random_ops.random_uniform([2, 3, 4, 3, 4], dtype=dtype)

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x_i = array_ops.gather(x, i)
        return op_func(x_i)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 2)

  @parameterized.parameters(
      (fft_ops.irfft,),
      (fft_ops.irfft2d,),
      (fft_ops.irfft3d,),
  )
  @test.disable_with_predicate(
      pred=test.is_built_with_rocm,
      skip_message="Disable subtest on ROCm due to rocfft issues")
  def test_irfft(self, op_func):
    if config.list_physical_devices("GPU"):
      # TODO(b/149957923): The test is flaky
      self.skipTest("b/149957923: irfft vectorization flaky")
    for dtype in (dtypes.complex64, dtypes.complex128):
      shape = [2, 3, 4, 3, 4]
      x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
      x = math_ops.cast(x, dtype=dtype)

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x_i = array_ops.gather(x, i)
        return op_func(x_i)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 2)


class VariableTest(PForTestCase):

  def test_create_variable_once(self):
    x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
    y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
    a_var = []

    def f(z):
      if not a_var:
        a_var.append(variables.Variable(lambda: y, name="a"))
      return math_ops.matmul(z, a_var[0] / 16)

    pfor_control_flow_ops.vectorized_map(f, x)

  @test_util.run_v2_only
  def test_create_variable_repeated(self):
    x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
    y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)

    def f(z):
      a_var = variables.Variable(lambda: y, name="a") / 4
      return math_ops.matmul(z, a_var / 16)

    # Note that this error is only raised under v2 behavior.
    with self.assertRaisesRegex(
        ValueError,
        "tf.function-decorated function tried to create variables on non-first"
    ):
      pfor_control_flow_ops.vectorized_map(f, x)

  @test_util.run_all_in_graph_and_eager_modes
  def test_variable_shape(self):
    v = resource_variable_ops.ResourceVariable([1, 2])

    def loop_fn(_):
      return resource_variable_ops.variable_shape(v.handle)

    self._test_loop_fn(loop_fn, 2)


if __name__ == "__main__":
  test.main()