# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for testing DistributionStrategy descendants."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os
import tempfile

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_util
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect


class _TestException(Exception):
  pass


# Conditionally wrap the fn in a def_function.function (so it runs in graph
# mode).
def _maybe_run_in_function(fn, run_in_function=False):
  if not run_in_function or not context.executing_eagerly():
    return fn
  else:
    return def_function.function()(fn)


# May be the argument to either distribution.extended.call_for_each_replica()
# or get_replica_context().merge_call()
def _raise_exception_fn(_=None):
  raise _TestException()


# Must be the argument to a distribution.extended.call_for_each_replica() call,
# calls a get_replica_context().merge_call() that raises an exception.
def _merge_raises_fn():
  ds_context.get_replica_context().merge_call(_raise_exception_fn)


# Must be the argument to a get_replica_context().merge_call() call, calls
# dist.extended.call_for_each_replica() with a function that raises an
# exception.
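# (It is invoked via _merge_call_raises_fn below;
# DistributionTestBase._test_call_and_merge_exceptions checks that the
# resulting _TestException propagates out of the outermost
# call_for_each_replica() call.)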
def _call_raises_fn(dist):
  dist.extended.call_for_each_replica(_raise_exception_fn)


# Must be the argument to a distribution.extended.call_for_each_replica() call,
# calls a get_replica_context().merge_call() that calls a
# call_for_each_replica() that raises an exception.
def _merge_call_raises_fn():
  ds_context.get_replica_context().merge_call(_call_raises_fn)


# Must be the argument to a get_replica_context().merge_call() call, calls
# dist.extended.call_for_each_replica() with a function that calls a
# get_replica_context().merge_call() that raises an exception.
def _call_merge_raises_fn(dist):
  dist.extended.call_for_each_replica(_merge_raises_fn)


# Must be the argument to a distribution.extended.call_for_each_replica() call,
# calls a get_replica_context().merge_call() that calls a
# call_for_each_replica() that calls a get_replica_context().merge_call() that
# raises an exception.
def _merge_call_merge_raises_fn():
  ds_context.get_replica_context().merge_call(_call_merge_raises_fn)


def _events_from_logdir(test_case, logdir):
  """Reads summary events from the log directory."""
  test_case.assertTrue(gfile.Exists(logdir))
  files = gfile.ListDirectory(logdir)
  test_case.assertLen(files, 1)
  records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
  result = []
  for r in records:
    event = event_pb2.Event()
    event.ParseFromString(r)
    result.append(event)
  return result


def create_variable_like_keras_layer(name, shape, dtype):
  """Utility for creating a variable the way a Keras layer would."""
  initializer = functools.partial(
      init_ops_v2.GlorotUniform(), shape, dtype=dtype)
  return variables.Variable(
      initial_value=initializer, name=name, trainable=True)


def is_optimizer_v2_instance(optimizer_obj):
  # For an optimizer instance, the v2 implementation has var_list as a required
  # argument.
  arg_spec = tf_inspect.getfullargspec(optimizer_obj.minimize)
  return "var_list" in arg_spec.args[:-len(arg_spec.defaults)]


class DistributionTestBase(test.TestCase):
  """Some tests that should work with any DistributionStrategy."""

  def _test_minimize_loss_eager(self, d):
    with d.scope():
      kernel = create_variable_like_keras_layer(
          name="kernel", shape=(1, 1), dtype=dtypes.float32)
      def loss(x):
        y = array_ops.reshape(
            gen_math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
        return y * y
      # TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a
      # common `implicit_grad` function and put it in DistributionStrategy.
      grad_fn = backprop.implicit_grad(loss)
      grad_fn = optimizer.get_filtered_grad_fn(grad_fn)

      def update(v, g):
        return v.assign_sub(0.2 * g)

      one = array_ops.identity([[1.]])

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get gradients, variables list.
        g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))

        # Update the variables using the gradients and the update() function.
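        # (For each variable: read its value, reduce the per-replica gradients
        # onto the variable's devices with ReduceOp.SUM, apply update() to
        # each copy via extended.update(), then read the value again so the
        # loop at the bottom can check that the error decreases.)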
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = d.extended.read_var(v)
          before_list.append(fetched)
          # control_dependencies irrelevant but harmless in eager execution
          with ops.control_dependencies([fetched]):
            g = d.extended.reduce_to(
                reduce_util.ReduceOp.SUM, g, destinations=v)
            with ops.control_dependencies(
                d.extended.update(v, update, args=(g,), group=False)):
              after_list.append(d.extended.read_var(v))
        return before_list, after_list

      for i in range(10):
        b, a = step()
        if i == 0:
          before, = b  # pylint: disable=unbalanced-tuple-unpacking
          after, = a  # pylint: disable=unbalanced-tuple-unpacking

      error_before = abs(before.numpy() - 1)
      error_after = abs(after.numpy() - 1)
      # Error should go down
      self.assertLess(error_after, error_before)

  def _test_minimize_loss_graph(self,
                                d,
                                soft_placement=False,
                                learning_rate=0.2):
    config = config_pb2.ConfigProto()
    config.allow_soft_placement = soft_placement
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    with context.graph_mode(), \
         ops.Graph().as_default(), \
         self.cached_session(config=config) as sess, \
         d.scope():
      kernel = create_variable_like_keras_layer(
          name="kernel", shape=(1, 1), dtype=dtypes.float32)

      def loss(x):
        y = array_ops.reshape(
            gen_math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
        return y * y

      grad_fn = backprop.implicit_grad(loss)

      def update(v, g):
        return v.assign_sub(learning_rate * g)

      one = array_ops.identity([[1.]])

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get gradients, variables list.
        g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))

        # Update the variables using the gradients and the update() function.
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = d.extended.read_var(v)
          before_list.append(fetched)
          with ops.control_dependencies([fetched]):
            g = d.extended.reduce_to(
                reduce_util.ReduceOp.SUM, g, destinations=v)
            with ops.control_dependencies(
                d.extended.update(v, update, args=(g,), group=False)):
              after_list.append(d.extended.read_var(v))
        return before_list, after_list

      before_out, after_out = step()
      variables.global_variables_initializer().run()
      for i in range(10):
        b, a = sess.run((before_out, after_out))
        if i == 0:
          before, = b
          after, = a

      error_before = abs(before - 1)
      error_after = abs(after - 1)
      # Error should go down
      self.assertLess(error_after, error_before)

  def _test_summary_for_replica_zero_only(self, d):
    logdir = tempfile.mkdtemp()

    def run_fn():
      """Function executed for each replica."""
      with summary_writer.as_default():
        replica_id = ds_context.get_replica_context().replica_id_in_sync_group
        return summary_ops.write("a", replica_id)

    with self.cached_session() as sess, d.scope(), \
         summary_ops.always_record_summaries():
      # We need global_step because the summary-writing op *always* takes
      # global_step as input, even when we always or never record summaries.
      global_step = training_util.get_or_create_global_step()
      if not context.executing_eagerly():
        # When executing eagerly, variables are initialized immediately after
        # creation, and their initializers will be None.
        global_step.initializer.run()
      summary_ops.set_step(0)
      summary_writer = summary_ops.create_file_writer(logdir)
      output = d.extended.call_for_each_replica(run_fn)
      unwrapped = d.unwrap(output)
      if not context.executing_eagerly():
        sess.run(summary_writer.init())
        sess.run(unwrapped)
        sess.run(summary_writer.close())

      events = _events_from_logdir(self, logdir)
      # There will be 2 entries: 1 summary file header entry, and 1 entry
      # written by replica 0.
      self.assertLen(events, 2)
      self.assertEqual(events[1].summary.value[0].tag, "a")
      self.assertEqual(events[1].summary.value[0].simple_value, 0.0)

  def _test_replica_id(self, d):
    with d.scope():
      expected_devices = [False] * len(d.extended.worker_devices)

      def mark_devices_fn():
        replica_id = self.evaluate(
            ds_context.get_replica_context().replica_id_in_sync_group)
        self.assertLess(replica_id, len(d.extended.worker_devices))
        self.assertFalse(expected_devices[replica_id])
        expected_devices[replica_id] = True

      d.extended.call_for_each_replica(mark_devices_fn)
      self.assertAllEqual(expected_devices,
                          [True] * len(d.extended.worker_devices))

  def _test_call_and_merge_exceptions(self, dist):
    with dist.scope():
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_raise_exception_fn)
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_merge_raises_fn)
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_merge_call_raises_fn)
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_merge_call_merge_raises_fn)

  def _input_fn_to_test_input_context(self, dataset_or_callable_fn,
                                      expected_num_replicas_in_sync,
                                      expected_num_input_pipelines,
                                      expected_input_pipeline_id):
    # Use a list of one element as a counter so that it can be captured by the
    # `_input_fn`. This counter is incremented by 1 each time an input_fn is
    # called. We use this counter to check whether the `input_pipeline_id`
    # matches the counter in the in-graph replication.
    worker_id_counter = [0]

    def _input_fn(input_context):
      """Input fn for testing."""
      self.assertIsNotNone(input_context)
      self.assertEqual(expected_num_replicas_in_sync,
                       input_context.num_replicas_in_sync)
      self.assertEqual(expected_num_input_pipelines,
                       input_context.num_input_pipelines)
      if expected_input_pipeline_id is not None:
        self.assertEqual(expected_input_pipeline_id,
                         input_context.input_pipeline_id)
      else:
        self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id)
        worker_id_counter[0] += 1

      return dataset_or_callable_fn()

    return _input_fn

  def _test_input_fn_iterable(
      self, strategy, input_fn, expected_values, ignore_order=False):
    assert_same = self.assertCountEqual if ignore_order else self.assertEqual

    iterable = strategy.distribute_datasets_from_function(input_fn)
    if context.executing_eagerly():
      iterator = iter(iterable)

      for expected_value in expected_values:
        computed_value = self.evaluate(
            list(strategy.experimental_local_results(next(iterator))))
        assert_same(expected_value, computed_value)

      with self.assertRaises(StopIteration):
        self.evaluate(strategy.experimental_local_results(next(iterator)))

      # After re-initializing the iterator, we should be able to iterate again.
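      # (Calling iter() on the distributed dataset again is expected to start
      # a fresh pass over the values produced by input_fn.)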
      iterator = iter(iterable)

      for expected_value in expected_values:
        computed_value = self.evaluate(
            list(strategy.experimental_local_results(next(iterator))))
        assert_same(expected_value, computed_value)
    else:
      iterator = dataset_ops.make_initializable_iterator(iterable)
      self._test_input_fn_iterator(iterator, strategy.extended.worker_devices,
                                   expected_values, test_reinitialize=True,
                                   ignore_order=ignore_order)

  def _test_input_fn_iterator(self,
                              iterator,
                              devices,
                              expected_values,
                              sess=None,
                              test_reinitialize=True,
                              ignore_order=False):
    evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
    evaluate(iterator.initializer)

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [distribute_utils.select_replica(r, next_element) for r in
           range(len(devices))])
      if ignore_order:
        self.assertCountEqual(expected_value, computed_value)
      else:
        self.assertEqual(expected_value, computed_value)

    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      evaluate(
          [distribute_utils.select_replica(r, next_element) for r in
           range(len(devices))])

    # After re-initializing the iterator, we should be able to iterate again.
    if test_reinitialize:
      evaluate(iterator.initializer)

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = evaluate([
            distribute_utils.select_replica(r, next_element) for r in
            range(len(devices))
        ])
        if ignore_order:
          self.assertCountEqual(expected_value, computed_value)
        else:
          self.assertEqual(expected_value, computed_value)

  def _test_global_step_update(self, strategy):
    with strategy.scope():
      global_step = variable_scope.get_variable(
          "global_step",
          shape=[],
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self.evaluate(variables.global_variables_initializer())

      def model_fn():
        train_op = global_step.assign_add(1)
        value = global_step.read_value()
        return train_op, value

      train_ops, value = strategy.extended.call_for_each_replica(model_fn)
      self.evaluate(strategy.group(train_ops))
      global_step_tensors = strategy.experimental_local_results(value)
      global_step_values = self.evaluate(global_step_tensors)
      self.assertEqual((1,) * len(global_step_tensors), global_step_values)

  def _test_numpy_dataset(self, strategy, session=None, run_in_function=False):
    if not isinstance(strategy, distribute_lib.StrategyV1):
      self.skipTest("n/a: V1 only")
    cached_session = session or self.cached_session()
    with strategy.scope(), cached_session as sess:
      x = np.asarray([[1, 2], [6, 12], [2, 4], [5, 10], [3, 6], [4, 8]])
      y = np.asarray([5, 4, 3, 2, 1, 0])
      batch_size = 6
      if not strategy.extended._global_batch_size:  # pylint: disable=protected-access
        batch_size = batch_size // strategy.num_replicas_in_sync

      ds = strategy.extended.experimental_make_numpy_dataset(
          (x, y), session=sess or self.cached_session())
      ds = ds.repeat(2)  # 2 epochs
      # We need to use the drop_remainder argument to get a known static
      # input shape, which is required for TPUs.
      drop_remainder = strategy.extended.experimental_require_static_shapes
      ds = ds.batch(batch_size, drop_remainder=drop_remainder)
      i = strategy.make_dataset_iterator(ds)

      self.evaluate(i.initializer)

      def run_and_concatenate(strategy, i):
        x, y = strategy.experimental_run(
            _maybe_run_in_function(lambda z: z, run_in_function), i)
        x, y = self.evaluate((strategy.experimental_local_results(x),
                              strategy.experimental_local_results(y)))
        return np.concatenate(x), np.concatenate(y)

      x_1, y_1 = run_and_concatenate(strategy, i)
      self.assertAllEqual(x, x_1)
      self.assertAllEqual(y, y_1)
      x_2, y_2 = run_and_concatenate(strategy, i)
      self.assertAllEqual(x, x_2)
      self.assertAllEqual(y, y_2)
      with self.assertRaises(errors.OutOfRangeError):
        run_and_concatenate(strategy, i)

  def _test_trainable_variable(self, strategy):
    for cls in [variables.VariableV1, variables.Variable]:
      with strategy.scope():
        v1 = cls(1.0)
        self.assertEqual(True, v1.trainable)

        v2 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ)
        self.assertEqual(False, v2.trainable)

        v3 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ,
                 trainable=True)
        self.assertEqual(True, v3.trainable)

        v4 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ,
                 trainable=False)
        self.assertEqual(False, v4.trainable)


class OneDeviceDistributionTestBase(test.TestCase):
  """Some tests that should work with any one-device DistributionStrategy."""

  def _test_run(self, strategy):
    out1 = strategy.run(lambda: array_ops.identity(4.))
    self.assertAllEqual([4.], self.evaluate(strategy.unwrap(out1)))

    out2 = strategy.run(lambda x: {"a": x * 2, "b": x * x}, args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([8.], out2_vals["a"])
    self.assertAllEqual([16.], out2_vals["b"])

    out3 = strategy.run(lambda b, a: a + 2 * b + 2, kwargs=out2)
    self.assertAllEqual([42.], self.evaluate(strategy.unwrap(out3)))

  def _test_all_reduce_sum(self, strategy):
    self._test_collective_comms(
        strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.]))

  def _test_all_reduce_sum_gradients(self, strategy):
    self._test_collective_comms_gradients(
        strategy, _all_sum, inputs=[4.], expected_grads=[4.])

  def _test_all_reduce_sum_gradient_tape(self, strategy):
    self._test_collective_comms_gradient_tape(
        strategy, _all_sum, inputs=[4.], expected_grads=[4.])

  def _test_all_reduce_mean(self, strategy):
    self._test_collective_comms(
        strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.]))

  def _test_all_reduce_mean_gradients(self, strategy):
    self._test_collective_comms_gradients(
        strategy, _all_mean, inputs=[5.], expected_grads=[5.])

  def _test_all_reduce_mean_gradient_tape(self, strategy):
    self._test_collective_comms_gradient_tape(
        strategy, _all_mean, inputs=[5.], expected_grads=[5.])
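
  # Note: with a single replica there is nothing to combine across devices, so
  # the expected outputs and gradients in the helpers above simply mirror the
  # inputs passed to the all-reduce.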
  def _test_collective_comms(self, strategy, comm_fn, inputs, expected):
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(inputs.initialize())
    outputs = self.evaluate(
        list(
            map(strategy.experimental_local_results,
                strategy.experimental_run(comm_fn, inputs))))
    self.assertAllEqual([expected[0]], outputs[0])
    self.assertAllEqual([expected[1]], outputs[1])

  def _test_collective_comms_gradients(self, strategy, comm_fn, inputs,
                                       expected_grads):
    if context.executing_eagerly():
      self.skipTest("`tf.gradients` is not supported with eager execution.")

    def step(c):
      x = array_ops.identity(42.)
      y = comm_fn(x) * c
      return gradients_impl.gradients(y, [x])[0]

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(step, inputs))))

  def _test_collective_comms_gradient_tape(self, strategy, comm_fn, inputs,
                                           expected_grads):

    def step(c):
      x = array_ops.identity(42.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = comm_fn(x) * c
      return tape.gradient(y, x)

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(step, inputs))))

  def _test_device_and_input_device_are_colocated(self, strategy):
    if context.executing_eagerly():
      self.skipTest(
          "cross-device tests are not supported with eager execution.")
    workers, _ = test_util.create_local_cluster(2, 0)
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.range(5))
    comm_fn = lambda x: x + 1
    run_op = strategy.experimental_run(comm_fn, inputs)
    with session_lib.Session(target=workers[1].target) as sess:
      sess.run(inputs.initialize())
      sess.run(run_op)

  def _test_device_and_input_device_are_colocated_with_function(self, strategy):
    if context.executing_eagerly():
      self.skipTest(
          "cross-device tests are not supported with eager execution.")
    workers, _ = test_util.create_local_cluster(2, 0)
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.range(5))
    comm_fn = lambda x: x + 1
    experimental_run = def_function.function()(strategy.experimental_run)
    with ops.device("/job:worker/replica:0/task:1/device:CPU:0"):
      # The tf.function must be defined on the right device as well.
      run_op = experimental_run(comm_fn, inputs)
    with session_lib.Session(target=workers[1].target) as sess:
      sess.run(inputs.initialize())
      sess.run(run_op)


class TwoDeviceDistributionTestBase(test.TestCase):
  """Some tests that should work with any two-device DistributionStrategy."""

  def _test_run(self, strategy, run_in_function=False):
    out1 = strategy.run(_maybe_run_in_function(
        lambda: ds_context.get_replica_context().replica_id_in_sync_group + 1,
        run_in_function))
    self.assertAllEqual([1, 2], self.evaluate(strategy.unwrap(out1)))

    out2 = strategy.run(_maybe_run_in_function(
        lambda x: {"a": x * 2, "b": x * x}, run_in_function), args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([2, 4], out2_vals["a"])
    self.assertAllEqual([1, 4], out2_vals["b"])

    out3 = strategy.run(_maybe_run_in_function(
        lambda b, a: a + 2 * b + 2, run_in_function), kwargs=out2)
    self.assertAllEqual([6, 14], self.evaluate(strategy.unwrap(out3)))

  def _test_all_reduce_sum(self, strategy, run_in_function=False):
    self._test_collective_comms(
        strategy,
        _all_sum,
        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
        expected=(4., [42., 43.]),
        run_in_function=run_in_function)

  def _test_all_reduce_sum_gradients(self, strategy, run_in_function=False):
    self._test_collective_comms_gradients(
        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.],
        run_in_function=run_in_function)

  def _test_all_reduce_sum_gradient_tape(self, strategy, run_in_function=False):
    self._test_collective_comms_gradient_tape(
        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.],
        run_in_function=run_in_function)

  def _test_all_reduce_mean(self, strategy, run_in_function=False):
    self._test_collective_comms(
        strategy,
        _all_mean,
        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
        expected=(2., [21., 21.5]),
        run_in_function=run_in_function)

  def _test_all_reduce_mean_gradients(self, strategy, run_in_function=False):
    self._test_collective_comms_gradients(
        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.],
        run_in_function=run_in_function)

  def _test_all_reduce_mean_gradient_tape(self, strategy,
                                          run_in_function=False):
    self._test_collective_comms_gradient_tape(
        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.],
        run_in_function=run_in_function)

  def _test_collective_comms(self, strategy, comm_fn, inputs, expected,
                             run_in_function=False):
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(inputs.initialize())
    outputs = self.evaluate(
        list(
            map(strategy.experimental_local_results,
                strategy.experimental_run(
                    _maybe_run_in_function(comm_fn, run_in_function), inputs))))
    self.assertAllEqual([expected[0], expected[0]], outputs[0])
    self.assertAllEqual([expected[1], expected[1]], outputs[1])

  def _test_collective_comms_gradients(self, strategy, comm_fn, inputs,
                                       expected_grads, run_in_function=False):
    if context.executing_eagerly() and not run_in_function:
      self.skipTest("`tf.gradients` is not supported with eager execution "
                    "without using tf.functions.")

    def step(c):
      x = array_ops.identity(42.)
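      # Each replica receives one element of `inputs` as c (1. and 3. in the
      # callers above). The gradient of the all-reduce w.r.t. x aggregates the
      # per-replica upstream gradients, so both replicas expect 1. + 3. = 4.
      # for _all_sum and their mean, 2., for _all_mean.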
      y = comm_fn(x) * c
      return gradients_impl.gradients(y, [x])[0]

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(
                    _maybe_run_in_function(step, run_in_function), inputs))))

  def _test_collective_comms_gradient_tape(self, strategy, comm_fn, inputs,
                                           expected_grads,
                                           run_in_function=False):

    def step(c):
      x = array_ops.identity(42.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = comm_fn(x) * c
      return tape.gradient(y, x)

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(
                    _maybe_run_in_function(step, run_in_function),
                    inputs))))


class RemoteSingleWorkerMirroredStrategyBase(DistributionTestBase):
  """Tests for a remote single-worker MirroredStrategy."""

  def _get_num_gpus(self):
    pass

  def _testNumReplicasInSync(self, distribution):
    self.assertEqual(self._get_num_gpus(), distribution.num_replicas_in_sync)

  def _testMinimizeLoss(self, distribution):
    if context.executing_eagerly():
      self._test_minimize_loss_eager(distribution)
    else:
      self._test_minimize_loss_graph(distribution, learning_rate=0.05)

  def _testDeviceScope(self, distribution):
    with distribution.scope():
      a = array_ops.identity(1.)
      with ops.device("/cpu:0"):
        b = array_ops.identity(1.)
      if context.executing_eagerly():
        device = "/job:worker/replica:0/task:0/device:CPU:0"
      else:
        device = "/job:worker/replica:0/task:0"
      self.assertEqual(a.device, device)
      self.assertEqual(b.device, "/job:worker/replica:0/task:0/device:CPU:0")

  def _testMakeInputFnIteratorWithDataset(self, distribution):
    dataset_fn = lambda: dataset_ops.Dataset.range(100)
    num_gpus = self._get_num_gpus()
    num_workers = 1

    expected_values = [[i+j for j in range(num_gpus)] * num_workers
                       for i in range(0, 100, num_gpus)]

    # A dummy cached_session is used in eager mode.
    with self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be
      # called multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          dataset_fn,
          expected_num_replicas_in_sync=num_workers*num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      iterator = distribution.make_input_fn_iterator(input_fn)
      self._test_input_fn_iterator(
          iterator, distribution.extended.worker_devices, expected_values, sess)

  def _testMakeInputFnIteratorWithCallable(self, distribution):
    def fn():
      dataset = dataset_ops.Dataset.range(100)
      it = dataset_ops.make_one_shot_iterator(dataset)
      return it.get_next
    num_gpus = self._get_num_gpus()
    num_workers = 1

    expected_values = []
    for i in range(0, 100, num_gpus):
      expected_values.append([i+j for j in range(num_gpus)] * num_workers)

    # A dummy cached_session is used in eager mode.
    with self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be
      # called multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          fn,
          expected_num_replicas_in_sync=num_workers*num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      iterator = distribution.make_input_fn_iterator(input_fn)
      self._test_input_fn_iterator(
          iterator, distribution.extended.worker_devices, expected_values, sess,
          test_reinitialize=False, ignore_order=True)


def _all_sum(value):
  ctx = ds_context.get_replica_context()
  return ctx.all_reduce(reduce_util.ReduceOp.SUM, value)


def _all_mean(value):
  ctx = ds_context.get_replica_context()
  return ctx.all_reduce(reduce_util.ReduceOp.MEAN, value)
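

# Both helpers above must run in a replica context; the collective-comms tests
# pass them as `comm_fn`, e.g. strategy.experimental_run(_all_sum, inputs) in
# the test bases above, so each replica contributes its local input element to
# the all-reduce.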