# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import nest


def get_dataset_from_tensor_slices(inp_array):
  dataset = dataset_ops.DatasetV2.from_tensor_slices(inp_array)
  # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
  if not tf2.enabled():
    dataset = dataset_ops.Dataset.from_tensor_slices(inp_array)
  return dataset


class AssertFlattenedMixin(object):
  """Mixin for specialized asserts."""

  def assert_equal_flattened(self, expected_results, actual_results):
    """Asserts that flattened results are equal.

    Due to the number of replicas in the strategy, the output may have a
    different structure and needs to be flattened for comparison.

    Args:
      expected_results: The results expected as a result of a computation.
      actual_results: The actual results of a computation.
64 """ 65 self.assertEqual(len(expected_results), len(actual_results)) 66 67 for i, expected_result in enumerate(expected_results): 68 final_result = [] 69 actual_result = actual_results[i] 70 for val in actual_result: 71 final_result.extend(val.numpy()) 72 self.assertAllEqual(expected_result, final_result) 73 74 75class InputIterationTest(test.TestCase, parameterized.TestCase, 76 AssertFlattenedMixin): 77 78 @combinations.generate( 79 combinations.combine( 80 distribution=strategy_combinations.all_strategies, 81 mode=["eager"] 82 )) 83 def testConstantNumpyInput(self, distribution): 84 85 @def_function.function 86 def run(x): 87 88 def computation(x): 89 return math_ops.square(x) 90 91 outputs = distribution.experimental_local_results( 92 distribution.experimental_run_v2(computation, args=(x,))) 93 return outputs 94 95 self.assertAllEqual( 96 constant_op.constant(4., shape=(distribution.num_replicas_in_sync)), 97 run(2.)) 98 99 @combinations.generate( 100 combinations.combine( 101 distribution=strategy_combinations.all_strategies, 102 mode=["eager"] 103 )) 104 def testStatefulExperimentalRunAlwaysExecute(self, distribution): 105 with distribution.scope(): 106 v = variables.Variable( 107 0.0, aggregation=variables.VariableAggregation.MEAN) 108 109 @def_function.function 110 def train_step(): 111 112 def assign_add(): 113 v.assign_add(1.0) 114 115 distribution.experimental_run_v2(assign_add) 116 return array_ops.zeros([]) 117 118 train_step() 119 self.assertAllEqual(1.0, v.numpy()) 120 121 @combinations.generate( 122 combinations.combine( 123 distribution=strategy_combinations.strategies_minus_tpu, 124 mode=["eager"])) 125 def testFullEager(self, distribution): 126 dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2) 127 128 def train_step(data): 129 return math_ops.square(data) 130 131 dist_dataset = distribution.experimental_distribute_dataset(dataset) 132 results = [] 133 for x in dist_dataset: 134 output = distribution.experimental_local_results( 135 distribution.experimental_run_v2(train_step, args=(x,))) 136 results.append(output) 137 self.assert_equal_flattened([[25., 36.], [49., 64.]], results) 138 139 @combinations.generate( 140 combinations.combine( 141 distribution=strategy_combinations.all_strategies, 142 mode=["eager"] 143 )) 144 def testStepInFunction(self, distribution): 145 dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2) 146 147 @def_function.function 148 def train_step(data): 149 return math_ops.square(data) 150 151 dist_dataset = distribution.experimental_distribute_dataset(dataset) 152 results = [] 153 for x in dist_dataset: 154 output = distribution.experimental_local_results( 155 distribution.experimental_run_v2(train_step, args=(x,))) 156 results.append(output) 157 self.assert_equal_flattened([[25., 36.], [49., 64.]], results) 158 159 @combinations.generate( 160 combinations.combine( 161 distribution=strategy_combinations.all_strategies, 162 mode=["eager"] 163 )) 164 def testRunInFunction(self, distribution): 165 dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2) 166 167 def train_step(data): 168 return math_ops.square(data) 169 170 @def_function.function 171 def f_train_step(input_data): 172 return distribution.experimental_local_results( 173 distribution.experimental_run_v2(train_step, args=(input_data,))) 174 175 dist_dataset = distribution.experimental_distribute_dataset(dataset) 176 results = [] 177 for x in dist_dataset: 178 output = f_train_step(x) 179 results.append(output) 180 
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy
          ],
          mode=["eager"]))
  def testNestedOutput(self, distribution):
    dataset = get_dataset_from_tensor_slices([0, 1, 2, 3]).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        return [{
            "a": x - 1,
            "b": x + 1
        }]

      inputs = next(iterator)
      outputs = distribution.experimental_run_v2(computation, args=(inputs,))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    results = run(input_iterator)
    for replica in range(distribution.num_replicas_in_sync):
      # The input dataset is range(4), so the replica id is the same as the
      # input value.
      self.assertAllEqual(results[0]["a"][replica], [replica - 1])
      self.assertAllEqual(results[0]["b"][replica], [replica + 1])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunctionAutoGraphApplication(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(input_data,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = f_train_step(x)
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetIterationInFunction(self, distribution):
    with distribution.scope():
      a = variables.Variable(
          1.0, aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)

    def train_step(_):
      a.assign_add(1.0)

    @def_function.function
    def f_train_step(dist_dataset):
      number_of_steps = constant_op.constant(0.0)
      product_of_means = constant_op.constant(2.0)
      for x in dist_dataset:  # loop with values modified each iteration
        number_of_steps += 1
        product_of_means *= math_ops.cast(
            distribution.reduce("MEAN", x, axis=0), product_of_means.dtype)

      for y in dist_dataset:  # loop with no intermediate state
        distribution.experimental_run_v2(train_step, args=(y,))

      return number_of_steps, product_of_means

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    number_of_steps, product_of_means = f_train_step(dist_dataset)
    self.assertEqual(2, number_of_steps.numpy())
    self.assertNear((2 * (5+6)/2 * (7+8)/2), product_of_means.numpy(), 1e-3)

    # We set the initial value of `a` to 1 and iterate through the dataset 2
    # times (4/2 where 4 is the number of dataset elements and 2 is the batch
    # size). Hence the final result is 3.
    self.assertEqual(3.0, (a.numpy()))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetAssertWithDynamicBatch(self, distribution):
    # Regression test for github issue 33517.
    def step_fn(data):
      assert_op = control_flow_ops.Assert(math_ops.less_equal(
          math_ops.reduce_max(data), 100.), [data])
      with ops.control_dependencies([assert_op]):
        return math_ops.square(data)

    @def_function.function
    def train(dataset):
      results = []
      iterator = iter(dataset)
      # we iterate through the loop 2 times since we have 3 elements and a
      # global batch of 2.
      for _ in range(2):
        elem = next(iterator)
        output = distribution.experimental_local_results(
            distribution.experimental_run_v2(step_fn, args=(elem,)))
        results.append(output)
      return results

    dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7.,]).batch(2)
    # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
    if not tf2.enabled():
      dataset = dataset_ops.Dataset.from_tensor_slices([5., 6., 7.,]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = train(dist_dataset)

    expected_results = [[25., 36.], [49.]]
    self.assertEqual(len(expected_results), len(results))

    # Need to expand results since output will be grouped differently depending
    # on the number of replicas.
    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      def computation(x):
        return math_ops.reduce_mean(x)
      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.experimental_run_v2(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([5.5, 7.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithGetNextOutsideFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(inputs):
      def computation(x):
        return math_ops.reduce_mean(x)
      outputs = distribution.experimental_local_results(
          distribution.experimental_run_v2(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([5.5, 7.], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testStrategyReduceWithDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      inputs = next(iterator)
      return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0)

    self.assertAllEqual(6., run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testStrategyReduceWithDynamicShapesRank2(self, distribution):
    dataset = get_dataset_from_tensor_slices(
        [[1., 1.], [1., 1.], [1., 1.]]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      inputs = next(iterator)
      return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0)

    self.assertAllEqual([1., 1.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithSizeOp(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(inputs):
      def computation(x):
        return array_ops.size_v2(x)
      outputs = distribution.experimental_local_results(
          distribution.experimental_run_v2(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([2, 1], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithFirstReplicaNotMaximumShape(self, distribution):
    def dataset_fn(_):
      dataset1 = get_dataset_from_tensor_slices([[1., 2.], [1., 2.]])
      dataset2 = get_dataset_from_tensor_slices([[1., 2., 3.],
                                                 [1., 2., 3.]])
      dataset = dataset1.concatenate(dataset2)
      dataset = dataset.batch(2, drop_remainder=True)
      return dataset

    input_iterator = iter(
        distribution.experimental_distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def run(inputs):
      def computation(x):
        return math_ops.reduce_mean(x)
      outputs = distribution.experimental_local_results(
          distribution.experimental_run_v2(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([1.5, 2.], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeEvenlyDivisibleDrop(self, distribution):
    # If the batch size is evenly divisible by the number of workers and we set
    # drop_remainder=True on the dataset, then DistributedIterator will use a
    # different (and more efficient) code path which avoids some control flow
    # ops.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        2, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5., 6.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeNotDivisibleDrop(self, distribution):
    # If each batch is not evenly divisible by the number of workers,
    # the remainder will be dropped.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        1, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeEvenlyDivisibleNoDrop(self, distribution):
    # Setting drop_remainder=False on the dataset causes DistributedIterator
    # to use get_next_as_optional(), even if the batched dataset is evenly
    # divisible by the number of workers.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        2, drop_remainder=False)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5., 6.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetPartialBatchWithMixedOutputs(self, distribution):
    # A partial batch produces a dynamically sized output alongside a
    # statically shaped one.
    dataset = get_dataset_from_tensor_slices([5.]).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        # Fixed size output with a dynamic sized output.
        return array_ops.zeros([3]), math_ops.square(x)

      return distribution.experimental_run_v2(
          computation, args=(next(iterator),))

    results = run(input_iterator)

    # First result is fixed for all replicas.
    for replica_id in range(distribution.num_replicas_in_sync):
      self.assertAllEqual([0., 0., 0.],
                          distribution.experimental_local_results(
                              results[0])[replica_id])
    # Only the first replica receives data, so only it has a value for the
    # dynamically shaped output.
    self.assertAllEqual([25.],
                        distribution.experimental_local_results(results[1])[0])
    # The remaining replicas get empty dynamically shaped outputs.
    for replica_id in range(1, distribution.num_replicas_in_sync):
      self.assertAllEqual([],
                          distribution.experimental_local_results(
                              results[1])[replica_id])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testIterationInsideFunction(self, distribution):

    def step_fn(data):
      return math_ops.square(data)

    @def_function.function
    def train(dataset):
      results = []
      iterator = iter(dataset)
      # we iterate through the loop 2 times since we have 4 elements and a
      # global batch of 2.
      for _ in range(2):
        elem = next(iterator)
        output = distribution.experimental_local_results(
            distribution.experimental_run_v2(step_fn, args=(elem,)))
        results.append(output)
      return results

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = train(dist_dataset)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testIterationOutsideFunction(self, distribution):

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(input_data,)))

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)
    results = []
    # we iterate through the loop 2 times since we have 4 elements and a
    # global batch of 2.
    for _ in range(2):
      output = f_train_step(next(iterator))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)


class GradientTapeTest(test.TestCase, parameterized.TestCase,
                       AssertFlattenedMixin):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStepInFunctionGradient(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def train_step(x):
      def computation(x):
        return math_ops.square(x)
      with backprop.GradientTape() as tape:
        tape.watch(x)  # Manually watch non-variable tensors.
        y = computation(x)
      grads = tape.gradient(y, x)
      return grads

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[10., 12.], [14., 16.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunctionGradient(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def run(x):
      def train_step(x):
        def computation(x):
          return math_ops.square(x)
        with backprop.GradientTape() as tape:
          tape.watch(x)  # Manually watch non-variable tensors.
          y = computation(x)
        grads = tape.gradient(y, x)
        return grads
      return distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(x,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = run(x)
      results.append(output)
    self.assert_equal_flattened([[10., 12.], [14., 16.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"],
          model_in_tf_function=[True, False]
      ))
  def testNestedFunction(self, distribution, model_in_tf_function):
    def model(x):
      return x * x

    if model_in_tf_function:
      model = def_function.function(model)

    with distribution.scope():
      x = variables.Variable(1.0)

    @def_function.function
    def train_step():
      def replica_step():
        with backprop.GradientTape() as tape:
          y = model(x)
        return tape.gradient(y, x)
      return distribution.experimental_run_v2(replica_step)

    grads = distribution.experimental_local_results(train_step())
    self.assertLen(grads, distribution.num_replicas_in_sync)
    self.assertTrue(all(g is not None for g in grads))


class KerasModelsTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_single_keras_layer_experimental_run(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = keras.layers.Dense(4, name="dense")

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        return grads

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_creation_experimental_run(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        return grads

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_optimizer_experimental_run(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_subclass_model_optimizer_experimental_run(self, distribution):
    def get_subclass_model():

      class KerasSubclassModel(keras.Model):

        def __init__(self):
          super(KerasSubclassModel, self).__init__()
          self.l = keras.layers.Dense(4, name="dense")

        def call(self, x):
          return self.l(x)

      return KerasSubclassModel()
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = get_subclass_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_optimizer_experimental_run_loop(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      for _ in range(5):
        distribution.experimental_run_v2(step_fn, args=(next(iterator),))

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_lstm(self, distribution):

    batch_size = 32

    def create_lstm_model():
      model = keras.models.Sequential()
      # We only have LSTM variables so we can detect no gradient issues more
      # easily.
      model.add(
          keras.layers.LSTM(1, return_sequences=False, input_shape=(10, 1)))
      return model

    def create_lstm_data():
      seq_length = 10

      x_train = np.random.rand(batch_size, seq_length, 1).astype("float32")
      y_train = np.random.rand(batch_size, 1).astype("float32")
      return x_train, y_train

    x, y = create_lstm_data()
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
    dataset = dataset.batch(batch_size, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = create_lstm_model()
      optimizer = keras.optimizer_v2.gradient_descent.SGD()

    @def_function.function
    def train_step(input_iterator):

      def step_fn(inputs):
        inps, targ = inputs
        with backprop.GradientTape() as tape:
          output = model(inps)
          loss = math_ops.reduce_mean(
              keras.losses.binary_crossentropy(
                  y_true=targ, y_pred=output, from_logits=False))
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(input_iterator),))
      return distribution.experimental_local_results(outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def test_nested_tf_functions(self, distribution):
    # The test builds two computations with keras layers, one with a nested
    # tf.function and the other without. We run these computations
    # independently on models with the same weights, and make sure the
    # variables are still the same after one training step.

    inputs = np.random.random((10, 3)).astype(np.float32)
    targets = np.ones((10, 4), dtype=np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).repeat()
    dataset = dataset.batch(10, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    def get_model():
      x = keras.layers.Input(shape=(3,), name="input")
      y = keras.layers.Dense(4, name="dense")(x)
      model = keras.Model(x, y)
      return model

    with distribution.scope():
      model = get_model()
      optimizer = keras.optimizer_v2.gradient_descent.SGD(0.1, momentum=0.01)
      weights_file = os.path.join(self.get_temp_dir(), ".h5")
      model.save_weights(weights_file)
      model2 = get_model()
      model2.load_weights(weights_file)

    # Make sure model and model2 variables are in sync when initialized.
    for model_v, model2_v in zip(model.variables, model2.variables):
      self.assertAllClose(model_v.numpy(), model2_v.numpy())

    def compute_loss(images, targets):
      outputs = model(images)
      return math_ops.reduce_sum(outputs - targets)

    @def_function.function
    def train_step_without_nested_tf_function(inputs):

      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          loss = compute_loss(images, targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))

      distribution.experimental_run_v2(step_fn, args=(inputs,))

    @def_function.function
    def compute_loss2(images, targets):
      outputs = model2(images)
      return math_ops.reduce_sum(outputs - targets)

    @def_function.function
    def train_step_with_nested_tf_function(inputs):

      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          loss = compute_loss2(images, targets)
        grads = tape.gradient(loss, model2.variables)
        optimizer.apply_gradients(zip(grads, model2.variables))

      distribution.experimental_run_v2(step_fn, args=(inputs,))

    inputs = next(input_iterator)

    train_step_without_nested_tf_function(inputs)
    train_step_with_nested_tf_function(inputs)

    # Make sure model and model2 variables are still in sync.
    for model_v, model2_v in zip(model.variables, model2.variables):
      self.assertAllClose(model_v.numpy(), model2_v.numpy())

  def _get_dataset(self):
    inputs = np.zeros((10, 3), dtype=np.float32)
    targets = np.zeros((10, 4), dtype=np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    dataset = dataset.repeat(100)
    dataset = dataset.batch(10, drop_remainder=True)
    return dataset

  def _get_model(self):
    x = keras.layers.Input(shape=(3,), name="input")
    y = keras.layers.Dense(4, name="dense")(x)
    model = keras.Model(x, y)
    return model


if __name__ == "__main__":
  test.main()