# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Strategy combinations for combinations.combine()."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import tf2
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import cluster_resolver
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import one_device_strategy as one_device_lib
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import remote
from tensorflow.python.platform import flags
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.util.tf_export import tf_export

_TF_INTERNAL_API_PREFIX = "__internal__.distribute.combinations."

_did_connect_to_cluster = False
_topology = None
CollectiveAllReduceExtended = (
    collective_all_reduce_strategy.CollectiveAllReduceExtended)


def _version_chooser(tf1_cls, tf2_cls):

  def creator(*args, **kwargs):
    if tf2.enabled():
      return tf2_cls(*args, **kwargs)
    return tf1_cls(*args, **kwargs)

  return creator


MirroredStrategy = _version_chooser(mirrored_lib.MirroredStrategyV1,
                                    mirrored_lib.MirroredStrategy)
CentralStorageStrategy = _version_chooser(
    central_storage_strategy.CentralStorageStrategyV1,
    central_storage_strategy.CentralStorageStrategy)
OneDeviceStrategy = _version_chooser(one_device_lib.OneDeviceStrategyV1,
                                     one_device_lib.OneDeviceStrategy)
# Only V2 CollectiveAllReduceStrategy combinations are supported.
CollectiveAllReduceStrategy = (
    collective_all_reduce_strategy.CollectiveAllReduceStrategy)
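

# Illustrative only (not executed): the aliases above are factories returned
# by _version_chooser, so whether the V1 or V2 strategy class is instantiated
# is decided only when the factory is called, e.g.
#
#   strategy = MirroredStrategy(["/cpu:0"])  # V2 class iff tf2.enabled()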


# pylint: disable=missing-docstring
def _get_tpu_strategy_creator(steps_per_run,
                              use_single_core=False,
                              enable_packed_variable=False,
                              **kwargs):

  def _create_tpu_strategy():
    FLAGS = flags.FLAGS  # pylint: disable=invalid-name
    global _did_connect_to_cluster
    global _topology

    try:
      # Attempt to locally discover the TPU. This will fail for Cloud TPU, in
      # which case we fall back to the values passed as flags.
      resolver = tpu_cluster_resolver.TPUClusterResolver()
      did_automatically_resolve = True
    except ValueError:
      did_automatically_resolve = False

      # These flags will be defined by tpu_test_wrapper.py.
      resolver = tpu_cluster_resolver.TPUClusterResolver(
          tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
          zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
          project=hasattr(FLAGS, "project") and FLAGS.project or None,
      )

    # Only connect once per process, rather than per test method.
    if not _did_connect_to_cluster:
      if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
        remote.connect_to_cluster(resolver)
        _did_connect_to_cluster = True
      _topology = tpu_strategy_util.initialize_tpu_system(resolver)

    device_assignment = None
    if use_single_core:
      device_assignment = device_assignment_lib.DeviceAssignment(
          _topology,
          core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)

    # Steps per run is only supported in TF 1.x.
    if tf2.enabled():
      strategy = tpu_lib.TPUStrategy(resolver, device_assignment, **kwargs)
    else:
      strategy = tpu_lib.TPUStrategyV1(resolver, steps_per_run,
                                       device_assignment, **kwargs)
    strategy._enable_packed_variable_in_eager_mode = enable_packed_variable  # pylint: disable=protected-access
    return strategy

  return _create_tpu_strategy


def _mirrored_strategy_with_collective_key_base(devices):
  mirrored_lib.MirroredStrategyV1._collective_key_base += 100000
  mirrored_lib.MirroredStrategy._collective_key_base += 100000
  return MirroredStrategy(devices)


def _get_multi_worker_mirrored_creator(required_gpus):

  def _create_multi_worker_mirrored():
    tf_config = cluster_resolver.TFConfigClusterResolver()
    master = tf_config.master()
    if tf_config.rpc_layer:
      # Strip off the rpc_layer suffix.
      master = master[len("%s://" % tf_config.rpc_layer):]
    resolver = cluster_resolver.SimpleClusterResolver(
        cluster_spec=tf_config.cluster_spec(),
        task_type=tf_config.task_type,
        task_id=tf_config.task_id,
        master=master,
        environment=tf_config.environment,
        num_accelerators={"GPU": required_gpus},
        rpc_layer=tf_config.rpc_layer or "grpc",
    )
    # Disable health check. We don't have a reliable way to shut down the
    # strategy (and thus the health check) at the end of a test. Turning on
    # health check causes some flakiness since we re-create part of the server
    # when creating a strategy, and our tests are capable of handling failures.
    CollectiveAllReduceExtended._enable_check_health = False  # pylint: disable=protected-access
    # Always create the strategy in eager mode so that it starts the server and
    # configures the eager context. The eager context can no longer be
    # configured after initialization.
    with context.eager_mode():
      strategy = CollectiveAllReduceStrategy(cluster_resolver=resolver)
    # TODO(b/152320929): Wait for the cluster before proceeding, otherwise
    # collectives may hang if any worker launches collectives before the chief
    # creates the strategy.
    try:
      multi_process_runner.get_barrier().wait()
    except ValueError:
      # If the creator is called in the main process,
      # multi_process_runner.get_barrier() raises ValueError, which is safe to
      # ignore.
      pass
    return strategy

  return _create_multi_worker_mirrored
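

# Illustrative only (not executed): TFConfigClusterResolver above reads the
# cluster layout from the TF_CONFIG environment variable, which the
# multi-process test harness sets for each subprocess. For a chief + 1 worker
# cluster it looks roughly like this (addresses are hypothetical):
#
#   TF_CONFIG = {
#       "cluster": {
#           "chief": ["localhost:12345"],
#           "worker": ["localhost:23456"],
#       },
#       "task": {"type": "worker", "index": 0},
#   }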


def _deferred_pool_runner(has_chief, num_workers, initializer=None):
  """Returns a callable that returns the pool runner.

  It creates the pool runner only upon first invocation. This avoids creating
  it when this file is imported.

  Args:
    has_chief: whether there should be a chief.
    num_workers: the number of workers excluding the chief.
    initializer: initializer of each process.

  Returns:
    A callable that returns the runner.
  """

  container = []

  def get_or_create():
    if not container:
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=num_workers,
          num_ps=0,
          has_eval=False)
      runner = multi_process_runner.MultiProcessPoolRunner(
          cluster_spec, initializer=initializer)
      container.append(runner)
    return container[0]

  return get_or_create


# We need to create the strategy in the initializer to start the server before
# any test runs.
_two_worker_pool = _deferred_pool_runner(
    has_chief=True,
    num_workers=1,
    initializer=_get_multi_worker_mirrored_creator(required_gpus=0))
_four_worker_pool = _deferred_pool_runner(
    has_chief=True,
    num_workers=3,
    initializer=_get_multi_worker_mirrored_creator(required_gpus=0))
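

# Illustrative sketch (hypothetical usage, not executed here): tests fetch the
# pool lazily and run a function in every process of the cluster, e.g.
#
#   runner = _two_worker_pool()   # the pool runner is created on first call
#   runner.run(some_test_fn)      # some_test_fn is a hypothetical test body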


# pylint: disable=g-long-lambda
default_strategy = combinations.NamedDistribution(
    "Default",
    distribution_strategy_context._get_default_strategy,  # pylint: disable=protected-access
    required_gpus=None)
one_device_strategy = combinations.NamedDistribution(
    "OneDeviceCPU", lambda: OneDeviceStrategy("/cpu:0"), required_gpus=None)
one_device_strategy_gpu = combinations.NamedDistribution(
    "OneDeviceGPU", lambda: OneDeviceStrategy("/gpu:0"), required_gpus=1)
one_device_strategy_on_worker_1 = combinations.NamedDistribution(
    "OneDeviceOnWorker1CPU",
    lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/cpu:0"),
    required_gpus=None)
one_device_strategy_gpu_on_worker_1 = combinations.NamedDistribution(
    "OneDeviceOnWorker1GPU",
    lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/gpu:0"),
    required_gpus=1)
tpu_strategy = combinations.NamedDistribution(
    "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True)
tpu_strategy_packed_var = combinations.NamedDistribution(
    "TPUPackedVar",
    _get_tpu_strategy_creator(steps_per_run=2, enable_packed_variable=True),
    required_tpu=True)
tpu_strategy_one_step = combinations.NamedDistribution(
    "TPUOneStep", _get_tpu_strategy_creator(steps_per_run=1), required_tpu=True)
tpu_strategy_one_core = combinations.NamedDistribution(
    "TPUOneCore",
    _get_tpu_strategy_creator(steps_per_run=2, use_single_core=True),
    required_tpu=True)
tpu_strategy_one_step_one_core = combinations.NamedDistribution(
    "TPUOneStepOneCore",
    _get_tpu_strategy_creator(steps_per_run=1, use_single_core=True),
    required_tpu=True)
cloud_tpu_strategy = combinations.NamedDistribution(
    "CloudTPU",
    _get_tpu_strategy_creator(steps_per_run=2),
    required_tpu=True,
    use_cloud_tpu=True)
mirrored_strategy_with_one_cpu = combinations.NamedDistribution(
    "Mirrored1CPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/cpu:0"]))
mirrored_strategy_with_one_gpu = combinations.NamedDistribution(
    "Mirrored1GPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0"]),
    required_gpus=1)
mirrored_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
    "MirroredCPUAndGPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0", "/cpu:0"]),
    required_gpus=1)
mirrored_strategy_with_two_gpus = combinations.NamedDistribution(
    "Mirrored2GPUs",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0", "/gpu:1"]),
    required_gpus=2)
# Should call set_virtual_cpus_to_at_least(3) in your test's setUp methods.
mirrored_strategy_with_cpu_1_and_2 = combinations.NamedDistribution(
    "Mirrored2CPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/cpu:1", "/cpu:2"]))
mirrored_strategy_with_cpu_1_and_2.__doc__ = (
    """Mirrored strategy with 2 virtual CPUs.

    Should set up logical devices before use.
    """)
central_storage_strategy_with_two_gpus = combinations.NamedDistribution(
    "CentralStorage2GPUs",
    lambda: CentralStorageStrategy(["/gpu:0", "/gpu:1"]),
    required_gpus=2)
central_storage_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
    "CentralStorageCPUAndGPU",
    lambda: CentralStorageStrategy(["/gpu:0", "/cpu:0"]),
    required_gpus=1)
# chief + 1 worker, with CPU.
multi_worker_mirrored_2x1_cpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x1CPU",
    _get_multi_worker_mirrored_creator(required_gpus=0),
    has_chief=True,
    num_workers=1,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
# chief + 1 worker, with 1 GPU each.
multi_worker_mirrored_2x1_gpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x1GPU",
    _get_multi_worker_mirrored_creator(required_gpus=1),
    has_chief=True,
    num_workers=1,
    required_gpus=1,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
# chief + 1 worker, with 2 GPUs each.
multi_worker_mirrored_2x2_gpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x2GPU",
    _get_multi_worker_mirrored_creator(required_gpus=2),
    has_chief=True,
    num_workers=1,
    required_gpus=2,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
# chief + 3 workers, with CPU.
multi_worker_mirrored_4x1_cpu = combinations.NamedDistribution(
    "MultiWorkerMirrored4x1CPU",
    _get_multi_worker_mirrored_creator(required_gpus=0),
    has_chief=True,
    num_workers=3,
    pool_runner_fn=_four_worker_pool,
    no_xla=True,
)


graph_and_eager_modes = ["graph", "eager"]
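

# Illustrative only (typical usage assumed, not executed here): tests consume
# the NamedDistribution objects above via combinations.combine(), e.g.
#
#   @combinations.generate(
#       combinations.combine(
#           distribution=[mirrored_strategy_with_gpu_and_cpu],
#           mode=["graph", "eager"]))
#   def testSomething(self, distribution):
#     with distribution.scope():
#       ...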


# TODO(crccw): remove after tf-nightly picks up the new API.
def set_virtual_cpus_to_at_least(num_virtual_cpus):
  test_util.set_logical_devices_to_at_least("CPU", num_virtual_cpus)


strategies_minus_tpu = [
    default_strategy,
    one_device_strategy,
    one_device_strategy_gpu,
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
    central_storage_strategy_with_gpu_and_cpu,
]

strategies_minus_default_and_tpu = [
    one_device_strategy,
    one_device_strategy_gpu,
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
]

tpu_strategies = [
    tpu_strategy,  # steps_per_run=2
    tpu_strategy_one_step,
    tpu_strategy_packed_var,
    cloud_tpu_strategy,
]

all_strategies_minus_default = strategies_minus_default_and_tpu + tpu_strategies

all_strategies = strategies_minus_tpu + tpu_strategies

two_replica_strategies = [
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
    multi_worker_mirrored_2x1_cpu,
    multi_worker_mirrored_2x1_gpu,
    tpu_strategy,  # steps_per_run=2
    tpu_strategy_one_step,
    central_storage_strategy_with_gpu_and_cpu,
]

four_replica_strategies = [
    multi_worker_mirrored_2x2_gpu,
    multi_worker_mirrored_4x1_cpu,
]

# TODO(b/159831907): replace with two_replica_strategies after the tests using
# it work with MWMS.
multidevice_strategies = [
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
    tpu_strategy,  # steps_per_run=2
    tpu_strategy_one_step
]

multiworker_strategies = [
    multi_worker_mirrored_2x1_cpu, multi_worker_mirrored_2x1_gpu,
    multi_worker_mirrored_2x2_gpu
]


def strategy_minus_tpu_combinations():
  return combinations.combine(
      distribution=strategies_minus_tpu, mode=["graph", "eager"])


def tpu_strategy_combinations():
  return combinations.combine(distribution=tpu_strategies, mode=["graph"])


def all_strategy_combinations():
  return strategy_minus_tpu_combinations() + tpu_strategy_combinations()


def all_strategy_minus_default_and_tpu_combinations():
  return combinations.combine(
      distribution=[
          one_device_strategy, one_device_strategy_gpu,
          mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus
      ],
      mode=["graph", "eager"])


def all_strategy_combinations_minus_default():
  return (all_strategy_minus_default_and_tpu_combinations() +
          tpu_strategy_combinations())
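

# Note (illustrative): the tf_export calls below expose these NamedDistribution
# constants under the TF2-only internal namespace, e.g.
#   tf.__internal__.distribute.combinations.one_device_strategy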


tf_export(
    _TF_INTERNAL_API_PREFIX + "central_storage_strategy_with_gpu_and_cpu",
    v1=[]).export_constant(__name__,
                           "central_storage_strategy_with_gpu_and_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "central_storage_strategy_with_two_gpus",
    v1=[]).export_constant(__name__, "central_storage_strategy_with_two_gpus")
tf_export(
    _TF_INTERNAL_API_PREFIX + "cloud_tpu_strategy",
    v1=[]).export_constant(__name__, "cloud_tpu_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "default_strategy",
    v1=[]).export_constant(__name__, "default_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_cpu_1_and_2",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_cpu_1_and_2")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_gpu_and_cpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_gpu_and_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_one_cpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_one_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_one_gpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_one_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_two_gpus",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_two_gpus")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_cpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_gpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x2_gpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x2_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "one_device_strategy",
    v1=[]).export_constant(__name__, "one_device_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "one_device_strategy_gpu",
    v1=[]).export_constant(__name__, "one_device_strategy_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy",
    v1=[]).export_constant(__name__, "tpu_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy_one_core",
    v1=[]).export_constant(__name__, "tpu_strategy_one_core")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy_packed_var",
    v1=[]).export_constant(__name__, "tpu_strategy_packed_var")