# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import nest


def get_dataset_from_tensor_slices(inp_array):
  dataset = dataset_ops.DatasetV2.from_tensor_slices(inp_array)
  # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
  if not tf2.enabled():
    dataset = dataset_ops.Dataset.from_tensor_slices(inp_array)
  return dataset


class AssertFlattenedMixin(object):
  """Mixin for specialized asserts."""

  def assert_equal_flattened(self, expected_results, actual_results):
    """Asserts that flattened results are equal.

    Due to the number of replicas in the strategy, the output may have a
    different structure and needs to be flattened for comparison.

    Args:
      expected_results: The results expected as a result of a computation.
      actual_results: The actual results of a computation.
    """
    self.assertEqual(len(expected_results), len(actual_results))

    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = actual_results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)


class InputIterationTest(test.TestCase, parameterized.TestCase,
                         AssertFlattenedMixin):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testConstantNumpyInput(self, distribution):

    @def_function.function
    def run(x):

      def computation(x):
        return math_ops.square(x)

      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(x,)))
      return outputs

    self.assertAllEqual(
        constant_op.constant(4., shape=(distribution.num_replicas_in_sync,)),
        run(2.))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStatefulExperimentalRunAlwaysExecute(self, distribution):
    with distribution.scope():
      v = variables.Variable(
          0.0, aggregation=variables.VariableAggregation.MEAN)

    @def_function.function
    def train_step():

      def assign_add():
        v.assign_add(1.0)

      distribution.run(assign_add)
      return array_ops.zeros([])

    train_step()
    self.assertAllEqual(1.0, v.numpy())

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.strategies_minus_tpu,
          mode=["eager"]))
  def testFullEager(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.run(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def testGetNextAsOptional(self, distribution):
    data = [5., 6., 7., 8.]
    dataset = get_dataset_from_tensor_slices(data).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def run(iterator):
      return distribution.experimental_local_results(
          distribution.run(
              train_step, args=(iterator.get_next_as_optional().get_value(),)))

    self.assert_equal_flattened([[25., 36.]], [run(iterator)])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def testGetNextAsOptionalExampleUsage(self, distribution):
    global_batch_size = 2
    steps_per_loop = 6
    dataset = dataset_ops.Dataset.range(
        8, output_type=dtypes.int32).batch(global_batch_size)
    distributed_iterator = iter(
        distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def train_fn(distributed_iterator):

      def step_fn(x):
        return x

      for _ in math_ops.range(steps_per_loop):
        optional_data = distributed_iterator.get_next_as_optional()
        if not optional_data.has_value():
          break
        distribution.run(step_fn, args=(optional_data.get_value(),))

    train_fn(distributed_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
  def testFullEagerTPU(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with self.assertRaisesRegex(NotImplementedError,
                                "does not support pure eager execution"):
      distribution.run(train_step, args=(next(input_iterator),))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStepInFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def train_step(data):
      return math_ops.square(data)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.run(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = f_train_step(x)
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testNestedOutput(self, distribution):
    dataset = get_dataset_from_tensor_slices([0, 1, 2, 3]).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        return [{
            "a": x - 1,
            "b": x + 1
        }]

      inputs = next(iterator)
      outputs = distribution.run(computation, args=(inputs,))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    results = run(input_iterator)
    for replica in range(distribution.num_replicas_in_sync):
      # The input dataset is range(4), so each replica's input value is the
      # same as its replica id.
      self.assertAllEqual(results[0]["a"][replica], [replica - 1])
      self.assertAllEqual(results[0]["b"][replica], [replica + 1])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunctionAutoGraphApplication(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = f_train_step(x)
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetIterationInFunction(self, distribution):
    with distribution.scope():
      a = variables.Variable(
          1.0, aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)

    def train_step(_):
      a.assign_add(1.0)

    @def_function.function
    def f_train_step(dist_dataset):
      number_of_steps = constant_op.constant(0.0)
      product_of_means = constant_op.constant(2.0)
      for x in dist_dataset:  # loop with values modified each iteration
        number_of_steps += 1
        product_of_means *= math_ops.cast(
            distribution.reduce("MEAN", x, axis=0), product_of_means.dtype)

      for y in dist_dataset:  # loop with no intermediate state
        distribution.run(train_step, args=(y,))

      return number_of_steps, product_of_means

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    number_of_steps, product_of_means = f_train_step(dist_dataset)
    self.assertEqual(2, number_of_steps.numpy())
    self.assertNear((2 * (5 + 6) / 2 * (7 + 8) / 2), product_of_means.numpy(),
                    1e-3)

    # `a` starts at 1 and the second loop runs 2 steps (4 dataset elements /
    # global batch of 2), each adding 1, so the final value is 3.
    self.assertEqual(3.0, a.numpy())

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetAssertWithDynamicBatch(self, distribution):
    # Regression test for github issue 33517.
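    # The step function asserts (via a control dependency) that every value
    # is <= 100, which must also work when the final global batch is partial,
    # i.e. when the batch dimension is dynamic.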
    def step_fn(data):
      assert_op = control_flow_ops.Assert(math_ops.less_equal(
          math_ops.reduce_max(data), 100.), [data])
      with ops.control_dependencies([assert_op]):
        return math_ops.square(data)

    @def_function.function
    def train(dataset):
      results = []
      iterator = iter(dataset)
      # We iterate through the loop 2 times since we have 3 elements and a
      # global batch of 2.
      for _ in range(2):
        elem = next(iterator)
        output = distribution.experimental_local_results(
            distribution.run(step_fn, args=(elem,)))
        results.append(output)
      return results

    dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7.]).batch(2)
    # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
    if not tf2.enabled():
      dataset = dataset_ops.Dataset.from_tensor_slices([5., 6., 7.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = train(dist_dataset)

    expected_results = [[25., 36.], [49.]]
    self.assertEqual(len(expected_results), len(results))

    # Need to expand results since the output will be grouped differently
    # depending on the number of replicas.
    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetIteratorWithoutFunction(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    self.assertAllEqual(
        distribution.experimental_local_results(input_iterator.get_next()),
        data[0:distribution.num_replicas_in_sync])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetIteratorWithFunction(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    @def_function.function
    def run(iterator):
      return distribution.experimental_local_results(iterator.get_next())

    local_results = run(input_iterator)
    self.assertAllEqual(local_results,
                        data[0:distribution.num_replicas_in_sync])
    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
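    # By default the distributed iterator prefetches each replica's slice to
    # that replica's device, so each local result should be backed by the
    # corresponding worker device (checked below).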
    input_iterator = iter(
        distribution.experimental_distribute_dataset(
            get_dataset_from_tensor_slices(data).batch(2)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetFunctionPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetHostPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.experimental_distribute_dataset(
            get_dataset_from_tensor_slices(data).batch(2),
            distribute_lib.InputOptions(
                experimental_prefetch_to_device=False)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    for result in local_results:
      self.assertEqual(result.backing_device,
                       device_util.resolve("/device:CPU:0"))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetFunctionHostPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
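    # With experimental_prefetch_to_device=False, values should remain in
    # host (CPU) memory rather than being prefetched to the TPU devices.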
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data),
            distribute_lib.InputOptions(experimental_prefetch_to_device=False)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    for result in local_results:
      self.assertEqual(result.backing_device,
                       device_util.resolve("/device:CPU:0"))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      def computation(x):
        return math_ops.reduce_mean(x)
      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([5.5, 7.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testDynamicShapesWithRunOptions(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
    options = distribute_lib.RunOptions(
        experimental_bucketizing_dynamic_shape=True)

    @def_function.function
    def run(iterator):

      def computation(x):
        return math_ops.reduce_mean(x)

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,), options=options))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([5.5, 7.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testDynamicOutputsWithX64(self, distribution):
    dataset = get_dataset_from_tensor_slices(
        [5]).map(lambda x: math_ops.cast(x, dtypes.int64)).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        return math_ops.add(x, x)

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    result = run(input_iterator)
    self.assertAllEqual([10], result[0])
    self.assertAllEqual([], result[1])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithGetNextOutsideFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(inputs):
      def computation(x):
        return math_ops.reduce_mean(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([5.5, 7.], run(next(input_iterator)))
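
  # NOTE: The dynamic-shape tests above assume exactly two replicas: batching
  # [5., 6., 7.] with a global batch size of 4 yields one partial batch that
  # is split as [5., 6.] on replica 0 and [7.] on replica 1, hence the
  # per-replica means [5.5, 7.].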

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testStrategyReduceWithDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      inputs = next(iterator)
      return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0)

    self.assertAllEqual(6., run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testStrategyReduceWithDynamicShapesRank2(self, distribution):
    dataset = get_dataset_from_tensor_slices(
        [[1., 1.], [1., 1.], [1., 1.]]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      inputs = next(iterator)
      return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0)

    self.assertAllEqual([1., 1.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithSizeOp(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(inputs):
      def computation(x):
        return array_ops.size_v2(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([2, 1], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testSegmentSumWithDynamicNumberOfSegments(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros(5, dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(3)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(example):
      segment_ids = array_ops.zeros_like_v2(example)
      num_segment = array_ops.shape(example)[0]
      # If the number of segments is dynamic, the output should have a
      # dynamic shape as well.
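      # (unsorted_segment_sum requires num_segments when it is called; here
      # it is the runtime batch size, 3 for the full batch and 2 for the
      # final partial one, so the output shape is only known at run time.)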
      return math_ops.unsorted_segment_sum(example, segment_ids, num_segment)

    # This assumes that there are exactly 2 replicas.
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((3,), outputs[0].shape)
    self.assertAllEqual((2,), outputs[1].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testReshapeWithDynamicInputs(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros((5, 1, 2), dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(3)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(example):
      # example: [<=3, 1, 2]
      # tile: [<=3, <=3, 2]
      tile = array_ops.tile(example, [1, array_ops.shape(example)[0], 1])
      # reshape1: [<=(3*3 = 9), 2]
      reshape1 = array_ops.reshape(tile, [-1, 2])

      # reshape2: [<=3, <=3, 2]
      reshape2 = array_ops.reshape(
          reshape1,
          [array_ops.shape(example)[0],
           array_ops.shape(example)[0], 2])

      # reshape3: [<=3, -1, 2]
      reshape3 = array_ops.reshape(reshape1,
                                   [array_ops.shape(example)[0], -1, 2])
      # reshape4: [-1, <=3, 2]
      reshape4 = array_ops.reshape(reshape1,
                                   [-1, array_ops.shape(example)[0], 2])
      return [reshape1, reshape2, reshape3, reshape4]

    # This assumes that there are exactly 2 replicas.
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((9, 2), outputs[0][0].values[0].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][1].values[0].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][2].values[0].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][3].values[0].shape)

    self.assertAllEqual((4, 2), outputs[0][0].values[1].shape)
    self.assertAllEqual((2, 2, 2), outputs[0][1].values[1].shape)
    self.assertAllEqual((2, 2, 2), outputs[0][2].values[1].shape)
    self.assertAllEqual((2, 2, 2), outputs[0][3].values[1].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testDynamicShapesWithFirstReplicaNotMaximumShape(self, distribution):
    def dataset_fn(_):
      dataset1 = get_dataset_from_tensor_slices([[1., 2.], [1., 2.]])
      dataset2 = get_dataset_from_tensor_slices([[1., 2., 3.],
                                                 [1., 2., 3.]])
      dataset = dataset1.concatenate(dataset2)
      dataset = dataset.batch(2, drop_remainder=True)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def run(inputs):
      def computation(x):
        return math_ops.reduce_mean(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([1.5, 2.], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testMapFnWithDynamicInputs(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros((20, 300, 32), dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(16)
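      # With 20 examples and a per-replica batch of 16, the replicas draw
      # successive batches of 16 and 4, so the batch dimension is dynamic
      # (see the shape assertions below).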
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    def embedding_lookup(inputs):
      embedding_weights = array_ops.zeros((1, 128))
      flat_inputs = array_ops.reshape(inputs, [-1])
      embeddings = array_ops.gather(embedding_weights, flat_inputs)
      embeddings = array_ops.reshape(embeddings, inputs.shape.as_list() + [128])
      return embeddings

    @def_function.function
    def step_fn(example):
      return map_fn.map_fn(
          embedding_lookup, example, fn_output_signature=dtypes.float32)

    # This assumes that there are exactly 2 replicas.
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((16, 300, 32, 128), outputs[0].shape)
    self.assertAllEqual((4, 300, 32, 128), outputs[1].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeEvenlyDivisibleDrop(self, distribution):
    # If the batch size is evenly divisible by the number of workers and we
    # set drop_remainder=True on the dataset, then DistributedIterator will
    # use a different (and more efficient) code path which avoids some
    # control flow ops.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        2, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5., 6.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeNotDivisibleDrop(self, distribution):
    # If each batch is not evenly divisible by the number of workers,
    # the remainder will be dropped.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        1, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeEvenlyDivisibleNoDrop(self, distribution):
    # Setting drop_remainder=False on the dataset causes DistributedIterator
    # to use get_next_as_optional(), even if the batched dataset is evenly
    # divisible by the number of workers.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        2, drop_remainder=False)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5., 6.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetPartialBatchWithMixedOutputs(self, distribution):
    # Tests a dynamic output size with a mix of static and dynamic outputs.
    dataset = get_dataset_from_tensor_slices([5.]).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        # A fixed-size output paired with a dynamically sized output.
        return array_ops.zeros([3]), math_ops.square(x)

      return distribution.run(
          computation, args=(next(iterator),))

    results = run(input_iterator)

    # The first result is fixed for all replicas.
    for replica_id in range(distribution.num_replicas_in_sync):
      self.assertAllEqual([0., 0., 0.],
                          distribution.experimental_local_results(
                              results[0])[replica_id])
    # Only the first replica receives data, so only its dynamic output is
    # non-empty.
    self.assertAllEqual([25.],
                        distribution.experimental_local_results(results[1])[0])
    # The remaining replicas receive empty tensors.
    for replica_id in range(1, distribution.num_replicas_in_sync):
      self.assertAllEqual([],
                          distribution.experimental_local_results(
                              results[1])[replica_id])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testIterationInsideFunction(self, distribution):

    def step_fn(data):
      return math_ops.square(data)

    @def_function.function
    def train(dataset):
      results = []
      iterator = iter(dataset)
      # We iterate through the loop 2 times since we have 4 elements and a
      # global batch of 2.
      for _ in range(2):
        elem = next(iterator)
        output = distribution.experimental_local_results(
            distribution.run(step_fn, args=(elem,)))
        results.append(output)
      return results

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = train(dist_dataset)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testIterationOutsideFunction(self, distribution):

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)
    results = []
    # We iterate through the loop 2 times since we have 4 elements and a
    # global batch of 2.
    for _ in range(2):
      output = f_train_step(next(iterator))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testMultiDeviceDataCapturedFunction(self, distribution):
    inputs = constant_op.constant([2., 3.])
    dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset))
    with distribution.scope():
      var = variables.Variable(1.0)

    @def_function.function
    def train_step(input_iterator):

      def func(inputs):
        return math_ops.square(inputs) + var

      per_replica_outputs = distribution.run(
          func, (next(input_iterator),))
      mean = distribution.reduce(
          reduce_util.ReduceOp.MEAN, per_replica_outputs, axis=None)
      for _ in dataset_ops.Dataset.range(1):
        per_replica_outputs = distribution.run(
            func, (next(input_iterator),))
        mean = distribution.reduce(
            reduce_util.ReduceOp.MEAN, per_replica_outputs, axis=None)
      return mean

    with distribution.scope():
      if distribution.num_replicas_in_sync == 1:
        self.assertAlmostEqual(10.0, self.evaluate(train_step(input_iterator)))
      else:
        self.assertAlmostEqual(7.5, self.evaluate(train_step(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetOutOfRange(self, distribution):
    with distribution.scope():
      a = variables.Variable(
          0.0, aggregation=variables.VariableAggregation.SUM)

    def train_step(val):
      a.assign_add(math_ops.reduce_sum(val))

    @def_function.function
    def f_train_step(iterator):
      distribution.run(train_step, args=(next(iterator),))
      return a

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    iterator = iter(dist_dataset)
    with self.assertRaises(errors.OutOfRangeError):
      for _ in range(100):
        f_train_step(iterator)

    self.assertAlmostEqual(26.0, a.numpy())


if __name__ == "__main__":
  test_util.main()