# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.tpu import tpu
from tensorflow.python.util import nest


def get_dataset_from_tensor_slices(inp_array):
  dataset = dataset_ops.DatasetV2.from_tensor_slices(inp_array)
  # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
  if not tf2.enabled():
    dataset = dataset_ops.Dataset.from_tensor_slices(inp_array)
  return dataset


class AssertFlattenedMixin(object):
  """Mixin for specialized asserts."""

  def assert_equal_flattened(self, expected_results, actual_results):
    """Asserts that flattened results are equal.

    Due to the number of replicas in the strategy, the output may have a
    different structure and needs to be flattened for comparison.

    Args:
      expected_results: The expected results of a computation.
      actual_results: The actual results of a computation.
    """
    self.assertEqual(len(expected_results), len(actual_results))

    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = actual_results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)


class InputIterationTest(test.TestCase, parameterized.TestCase,
                         AssertFlattenedMixin):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testConstantNumpyInput(self, distribution):

    @def_function.function
    def run(x):

      def computation(x):
        return math_ops.square(x)

      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(x,)))
      return outputs

    self.assertAllEqual(
        constant_op.constant(4., shape=(distribution.num_replicas_in_sync)),
        run(2.))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStatefulExperimentalRunAlwaysExecute(self, distribution):
    with distribution.scope():
      v = variables.Variable(
          0.0, aggregation=variables.VariableAggregation.MEAN)

    @def_function.function
    def train_step():

      def assign_add():
        v.assign_add(1.0)

      distribution.run(assign_add)
      return array_ops.zeros([])

    train_step()
    self.assertAllEqual(1.0, v.numpy())

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.strategies_minus_tpu,
          mode=["eager"]))
  def testFullEager(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.run(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def testGetNextAsOptional(self, distribution):
    data = [5., 6., 7., 8.]
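    # get_next_as_optional() wraps the next element in an Optional value, so
    # graph code can branch on has_value() instead of catching
    # OutOfRangeError when the dataset is exhausted.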
    dataset = get_dataset_from_tensor_slices(data).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def run(iterator):
      return distribution.experimental_local_results(
          distribution.run(
              train_step, args=(iterator.get_next_as_optional().get_value(),)))

    self.assert_equal_flattened([[25., 36.]], [run(iterator)])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def testGetNextAsOptionalExampleUsage(self, distribution):
    global_batch_size = 2
    steps_per_loop = 6
    dataset = dataset_ops.Dataset.range(
        8, output_type=dtypes.int32).batch(global_batch_size)
    distributed_iterator = iter(
        distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def train_fn(distributed_iterator):

      def step_fn(x):
        return x

      for _ in math_ops.range(steps_per_loop):
        optional_data = distributed_iterator.get_next_as_optional()
        if not optional_data.has_value():
          break
        distribution.run(step_fn, args=(optional_data.get_value(),))

    train_fn(distributed_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
  def testFullEagerTPU(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with self.assertRaisesRegex(NotImplementedError,
                                "does not support pure eager execution"):
      distribution.run(train_step, args=(next(input_iterator),))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStepInFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def train_step(data):
      return math_ops.square(data)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.run(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = f_train_step(x)
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testNestedOutput(self, distribution):
    dataset = get_dataset_from_tensor_slices([0, 1, 2, 3]).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        return [{
            "a": x - 1,
            "b": x + 1
        }]

      inputs = next(iterator)
      outputs = distribution.run(computation, args=(inputs,))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    results = run(input_iterator)
    for replica in range(distribution.num_replicas_in_sync):
      # The input dataset is range(4), so the replica id is same as input.
      self.assertAllEqual(results[0]["a"][replica], [replica - 1])
      self.assertAllEqual(results[0]["b"][replica], [replica + 1])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunctionAutoGraphApplication(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = f_train_step(x)
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetIterationInFunction(self, distribution):
    with distribution.scope():
      a = variables.Variable(
          1.0, aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)

    def train_step(_):
      a.assign_add(1.0)

    @def_function.function
    def f_train_step(dist_dataset):
      number_of_steps = constant_op.constant(0.0)
      product_of_means = constant_op.constant(2.0)
      for x in dist_dataset:  # loop with values modified each iteration
        number_of_steps += 1
        product_of_means *= math_ops.cast(
            distribution.reduce("MEAN", x, axis=0), product_of_means.dtype)

      for y in dist_dataset:  # loop with no intermediate state
        distribution.run(train_step, args=(y,))

      return number_of_steps, product_of_means

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    number_of_steps, product_of_means = f_train_step(dist_dataset)
    self.assertEqual(2, number_of_steps.numpy())
    self.assertNear((2 * (5+6)/2 * (7+8)/2), product_of_means.numpy(), 1e-3)

    # We set the initial value of `a` to 1 and iterate through the dataset 2
    # times (4/2 where 4 is the number of dataset elements and 2 is the batch
    # size). Hence the final result is 3.
    self.assertEqual(3.0, (a.numpy()))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetAssertWithDynamicBatch(self, distribution):
    # Regression test for github issue 33517.
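    # The Assert op must also handle the smaller final batch ([7.]) produced
    # by the three-element dataset below.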
    def step_fn(data):
      assert_op = control_flow_ops.Assert(math_ops.less_equal(
          math_ops.reduce_max(data), 100.), [data])
      with ops.control_dependencies([assert_op]):
        return math_ops.square(data)

    @def_function.function
    def train(dataset):
      results = []
      iterator = iter(dataset)
      # we iterate through the loop 2 times since we have 3 elements and a
      # global batch of 2.
      for _ in range(2):
        elem = next(iterator)
        output = distribution.experimental_local_results(
            distribution.run(step_fn, args=(elem,)))
        results.append(output)
      return results

    dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7.,]).batch(2)
    # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
    if not tf2.enabled():
      dataset = dataset_ops.Dataset.from_tensor_slices([5., 6., 7.,]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = train(dist_dataset)

    expected_results = [[25., 36.], [49.]]
    self.assertEqual(len(expected_results), len(results))

    # Need to expand results since output will be grouped differently depending
    # on the number of replicas.
    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetIteratorWithoutFunction(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    self.assertAllEqual(
        distribution.experimental_local_results(input_iterator.get_next()),
        data[0:distribution.num_replicas_in_sync])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetIteratorWithFunction(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    @def_function.function
    def run(iterator):
      return distribution.experimental_local_results(iterator.get_next())

    local_results = run(input_iterator)
    self.assertAllEqual(local_results,
                        data[0:distribution.num_replicas_in_sync])
    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
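    # By default the distributed iterator prefetches each replica's slice to
    # its device, so every local result should be backed by a worker device.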
    input_iterator = iter(
        distribution.experimental_distribute_dataset(
            get_dataset_from_tensor_slices(data).batch(2)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetFunctionPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetHostPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.experimental_distribute_dataset(
            get_dataset_from_tensor_slices(data).batch(2),
            distribute_lib.InputOptions(experimental_fetch_to_device=False)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    for result in local_results:
      self.assertEqual(result.backing_device,
                       device_util.resolve("/device:CPU:0"))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetFunctionHostPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
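    # experimental_fetch_to_device=False keeps the elements in host memory,
    # so each local result should be backed by the CPU device.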
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data),
            distribute_lib.InputOptions(experimental_fetch_to_device=False)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    for result in local_results:
      self.assertEqual(result.backing_device,
                       device_util.resolve("/device:CPU:0"))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      def computation(x):
        return math_ops.reduce_mean(x)
      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([5.5, 7.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategy, mode=["eager"]))
  def testDynamicShapesWithRunOptionsBucketizing(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
    options = distribute_lib.RunOptions(
        experimental_bucketizing_dynamic_shape=True)

    @def_function.function
    def run(iterator):

      def computation(x):
        return math_ops.reduce_mean(x)

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(
              computation, args=(inputs,), options=options))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([5.5, 7.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategy, mode=["eager"]))
  def testDynamicShapesWithRunOptionsDisableDynamicPadder(self, distribution):
    dataset = get_dataset_from_tensor_slices([5, 6, 7]).batch(4)
    mask_dataset = get_dataset_from_tensor_slices([1, 0, 1]).batch(4)
    dataset = dataset_ops.DatasetV2.zip((dataset, mask_dataset))

    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
    options = distribute_lib.RunOptions(
        experimental_xla_options=tpu.XLAOptions(
            enable_xla_dynamic_padder=False))

    @def_function.function
    def run(iterator):

      def computation(inputs):
        x, mask = inputs
        y = x * mask
        return math_ops.reduce_sum(y)

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,), options=options))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([5, 7], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testDynamicOutputsWithX64(self, distribution):
    dataset = get_dataset_from_tensor_slices(
        [5]).map(lambda x: math_ops.cast(x, dtypes.int64)).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        return math_ops.add(x, x)

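      # Only one element feeds two replicas, so the first replica receives
      # [5] and the second an empty int64 batch with a dynamic dimension.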
      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    result = run(input_iterator)
    self.assertAllEqual([10], result[0])
    self.assertAllEqual([], result[1])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithGetNextOutsideFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(inputs):
      def computation(x):
        return math_ops.reduce_mean(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([5.5, 7.], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testStrategyReduceWithDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      inputs = next(iterator)
      return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0)

    self.assertAllEqual(6., run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testStrategyReduceWithDynamicShapesRank2(self, distribution):
    dataset = get_dataset_from_tensor_slices(
        [[1., 1.], [1., 1.], [1., 1.]]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      inputs = next(iterator)
      return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0)

    self.assertAllEqual([1., 1.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithSizeOp(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(inputs):
      def computation(x):
        return array_ops.size_v2(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([2, 1], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testSegmentSumWithDynamicNumberOfSegments(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros(5, dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(3)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(example):
      segment_ids = array_ops.zeros_like_v2(example)
      num_segment = array_ops.shape(example)[0]
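      # num_segment is a runtime value (3 on the first replica, 2 on the
      # second), so the output's leading dimension is only known dynamically.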
      # If number of segments is dynamic, output should be a dynamic shape.
      return math_ops.unsorted_segment_sum(example, segment_ids, num_segment)

    # This assumes that there are exactly 2 replicas
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((3,), outputs[0].shape)
    self.assertAllEqual((2,), outputs[1].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testReshapeWithDynamicInputs(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros((5, 1, 2), dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(3)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(example):
      # example: [<=3, 1, 2]
      # tile: [<=3, <=3, 2]
      tile = array_ops.tile(example, [1, array_ops.shape(example)[0], 1])
      # reshape1: [<=(3*3 = 9), 2]
      reshape1 = array_ops.reshape(tile, [-1, 2])

      # reshape2: [<=3, <=3, 2]
      reshape2 = array_ops.reshape(
          reshape1,
          [array_ops.shape(example)[0],
           array_ops.shape(example)[0], 2])

      # reshape3: [<=3, -1, 2]
      reshape3 = array_ops.reshape(reshape1,
                                   [array_ops.shape(example)[0], -1, 2])
      # reshape4: [-1, <=3, 2]
      reshape4 = array_ops.reshape(reshape1,
                                   [-1, array_ops.shape(example)[0], 2])
      # Reshape1 is duplicated in order to test dynamic dimension on copies.
      return [reshape1, reshape2, reshape3, reshape4, reshape1]

    # This assumes that there are exactly 2 replicas
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((9, 2), outputs[0][0].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][1].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][2].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][3].shape)
    self.assertAllEqual((9, 2), outputs[0][4].shape)

    self.assertAllEqual((4, 2), outputs[1][0].shape)
    self.assertAllEqual((2, 2, 2), outputs[1][1].shape)
    self.assertAllEqual((2, 2, 2), outputs[1][2].shape)
    self.assertAllEqual((2, 2, 2), outputs[1][3].shape)
    self.assertAllEqual((4, 2), outputs[1][4].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testDynamicShapesWithFirstReplicaNotMaximumShape(self, distribution):
    def dataset_fn(_):
      dataset1 = get_dataset_from_tensor_slices([[1., 2.], [1., 2.]])
      dataset2 = get_dataset_from_tensor_slices([[1., 2., 3.],
                                                 [1., 2., 3.]])
      dataset = dataset1.concatenate(dataset2)
      dataset = dataset.batch(2, drop_remainder=True)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def run(inputs):
      def computation(x):
        return math_ops.reduce_mean(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([1.5, 2.], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testMapFnWithDynamicInputs(self, distribution):
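    # The 20-row dataset batched by 16 yields batches of 16 and 4 rows, so
    # the two replicas see different (dynamic) leading dimensions.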

    def dataset_fn(_):
      data = array_ops.zeros((20, 300, 32), dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(16)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    def embedding_lookup(inputs):
      embedding_weights = array_ops.zeros((1, 128))
      flat_inputs = array_ops.reshape(inputs, [-1])
      embeddings = array_ops.gather(embedding_weights, flat_inputs)
      embeddings = array_ops.reshape(embeddings, inputs.shape.as_list() + [128])
      return embeddings

    @def_function.function
    def step_fn(example):
      return map_fn.map_fn(
          embedding_lookup, example, fn_output_signature=dtypes.float32)

    # This assumes that there are exactly 2 replicas
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((16, 300, 32, 128), outputs[0].shape)
    self.assertAllEqual((4, 300, 32, 128), outputs[1].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeEvenlyDivisibleDrop(self, distribution):
    # If the batch size is evenly divisible by the number of workers and we set
    # drop_remainder=True on the dataset, then DistributedIterator will use a
    # different (and more efficient) code path which avoids some control flow
    # ops.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        2, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5., 6.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeNotDivisibleDrop(self, distribution):
    # If each batch is not evenly divisible by the number of workers,
    # the remainder will be dropped.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        1, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeEvenlyDivisibleNoDrop(self, distribution):
    # Setting drop_remainder=False on the dataset causes DistributedIterator
    # to use get_next_as_optional(), even if the batched dataset is evenly
    # divisible by the number of workers.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        2, drop_remainder=False)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5., 6.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetPartialBatchWithMixedOutputs(self, distribution):
    # Dynamic output size with a mix of static and dynamic outputs.
    dataset = get_dataset_from_tensor_slices([5.]).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        # Fixed size output with a dynamic sized output.
        return array_ops.zeros([3]), math_ops.square(x)

      return distribution.run(
          computation, args=(next(iterator),))

    results = run(input_iterator)

    # First result is fixed for all replicas.
    for replica_id in range(distribution.num_replicas_in_sync):
      self.assertAllEqual([0., 0., 0.],
                          distribution.experimental_local_results(
                              results[0])[replica_id])
    # Only the first replica receives data from the partial batch, so only it
    # produces a squared value.
    self.assertAllEqual([25.],
                        distribution.experimental_local_results(results[1])[0])
    # The remaining replicas receive empty inputs and produce empty outputs.
    for replica_id in range(1, distribution.num_replicas_in_sync):
      self.assertAllEqual([],
                          distribution.experimental_local_results(
                              results[1])[replica_id])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testIterationInsideFunction(self, distribution):

    def step_fn(data):
      return math_ops.square(data)

    @def_function.function
    def train(dataset):
      results = []
      iterator = iter(dataset)
      # we iterate through the loop 2 times since we have 4 elements and a
      # global batch of 2.
      for _ in range(2):
        elem = next(iterator)
        output = distribution.experimental_local_results(
            distribution.run(step_fn, args=(elem,)))
        results.append(output)
      return results

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = train(dist_dataset)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testIterationOutsideFunction(self, distribution):

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)
    results = []
    # we iterate through the loop 2 times since we have 4 elements and a
    # global batch of 2.
    for _ in range(2):
      output = f_train_step(next(iterator))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testMultiDeviceDataCapturedFunction(self, distribution):
    inputs = constant_op.constant([2., 3.])
    dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset))
    with distribution.scope():
      var = variables.Variable(1.0)

    @def_function.function
    def train_step(input_iterator):

      def func(inputs):
        return math_ops.square(inputs) + var

      per_replica_outputs = distribution.run(
          func, (next(input_iterator),))
      mean = distribution.reduce(
          reduce_util.ReduceOp.MEAN, per_replica_outputs, axis=None)
      for _ in dataset_ops.Dataset.range(1):
        per_replica_outputs = distribution.run(
            func, (next(input_iterator),))
        mean = distribution.reduce(
            reduce_util.ReduceOp.MEAN, per_replica_outputs, axis=None)
      return mean

    with distribution.scope():
      if distribution.num_replicas_in_sync == 1:
        self.assertAlmostEqual(10.0, self.evaluate(train_step(input_iterator)))
      else:
        self.assertAlmostEqual(7.5, self.evaluate(train_step(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetOutOfRange(self, distribution):
    with distribution.scope():
      a = variables.Variable(
          0.0, aggregation=variables.VariableAggregation.SUM)

    def train_step(val):
      a.assign_add(math_ops.reduce_sum(val))

    @def_function.function
    def f_train_step(iterator):
      distribution.run(train_step, args=(next(iterator),))
      return a

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    iterator = iter(dist_dataset)
    with self.assertRaises(errors.OutOfRangeError):
      for _ in range(100):
        f_train_step(iterator)

    self.assertAlmostEqual(26.0, a.numpy())

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testComputeLossWithDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        return losses.compute_weighted_loss(x, weights=array_ops.ones_like(x))

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([5.5, 7.], run(input_iterator))


if __name__ == "__main__":
  test_util.main()