1# Lint as: python3 2# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Parameter server strategy V2 class. 17 18This is currently under development and the API is subject to change. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import os 26 27from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib 28from tensorflow.python.distribute import device_util 29from tensorflow.python.distribute import distribute_lib 30from tensorflow.python.distribute import input_lib 31from tensorflow.python.distribute import mirrored_run 32from tensorflow.python.distribute import multi_worker_util 33from tensorflow.python.distribute import parameter_server_strategy 34from tensorflow.python.distribute import ps_values 35from tensorflow.python.distribute import sharded_variable 36from tensorflow.python.distribute import values 37from tensorflow.python.eager import remote 38from tensorflow.python.framework import config 39from tensorflow.python.framework import device as tf_device 40from tensorflow.python.framework import ops 41from tensorflow.python.framework import tensor_shape 42from tensorflow.python.ops import variable_scope as vs 43from tensorflow.python.platform import tf_logging as logging 44from tensorflow.python.training import server_lib 45from tensorflow.python.training.tracking import base as trackable 46from tensorflow.python.util import nest 47from tensorflow.python.util import tf_inspect 48from tensorflow.python.util.tf_export import tf_export 49 50ALLOWED_TASK_TYPES = ("chief", "worker", "ps") 51 52 53@tf_export("distribute.experimental.ParameterServerStrategy", v1=[]) 54class ParameterServerStrategyV2(distribute_lib.Strategy): 55 """An multi-worker tf.distribute strategy with parameter servers. 56 57 Parameter server training is a common data-parallel method to scale up a 58 machine learning model on multiple machines. A parameter server training 59 cluster consists of workers and parameter servers. Variables are created on 60 parameter servers and they are read and updated by workers in each step. 61 By default, workers read and update these variables independently without 62 synchronizing with each other. Under this configuration, it is known as 63 asynchronous training. 64 65 In TensorFlow 2, we recommend an architecture based on central coordination 66 for parameter server training. Each worker and parameter server runs a 67 `tf.distribute.Server`, and on top of that, a coordinator task is responsible 68 for creating resources on workers and parameter servers, dispatching 69 functions, and coordinating the training. The coordinator uses a 70 `tf.distribute.experimental.coordinator.ClusterCoordinator` to coordinate the 71 cluster, and a `tf.distribute.experimental.ParameterServerStrategy` to define 72 variables on parameter servers and computation on workers. 73 74 For the training to work, the coordinator dispatches `tf.function`s to be 75 executed on remote workers. Upon receiving requests from the coordinator, a 76 worker executes the `tf.function` by reading the variables from parameter 77 servers, executing the ops, and updating the variables on the parameter 78 servers. Each of the worker only processes the requests from the coordinator, 79 and communicates with parameter servers, without direct interactions with 80 other workers in the cluster. 81 82 As a result, failures of some workers do not prevent the cluster from 83 continuing the work, and this allows the cluster to train with instances that 84 can be occasionally unavailable (e.g. preemptible or spot instances). The 85 coordinator and parameter servers though, must be available at all times for 86 the cluster to make progress. 87 88 Note that the coordinator is not one of the training workers. Instead, it 89 creates resources such as variables and datasets, dispatchs `tf.function`s, 90 saves checkpoints and so on. In addition to workers, parameter servers and 91 the coordinator, an optional evaluator can be run on the side that 92 periodically reads the checkpoints saved by the coordinator and runs 93 evaluations against each checkpoint. 94 95 `ParameterServerStrategy` is supported with two training APIs: [Custom 96 Training Loop (CTL)] 97 (https://www.tensorflow.org/tutorials/distribute/custom_training) 98 and [Keras Training API, also known as `Model.fit`] 99 (https://www.tensorflow.org/tutorials/distribute/keras). CTL is recommended 100 when users prefer to define the details of their training loop, and 101 `Model.fit` is recommended when users prefer a high-level abstraction and 102 handling of training. 103 104 When using a CTL, `ParameterServerStrategy` has to work in conjunction with a 105 `tf.distribute.experimental.coordinator.ClusterCoordinator` object. 106 107 When using `Model.fit`, currently only the 108 `tf.keras.utils.experimental.DatasetCreator` input type is supported. 109 110 __Example code for coordinator__ 111 112 This section provides code snippets that are intended to be run on (the only) 113 one task that is designated as the coordinator. Note that `cluster_resolver`, 114 `variable_partitioner`, and `dataset_fn` arguments are explained in the 115 following "Cluster setup", "Variable partitioning", and "Dataset preparation" 116 sections. 117 118 With a CTL, 119 120 ```python 121 # Prepare a strategy to use with the cluster and variable partitioning info. 122 strategy = tf.distribute.experimental.ParameterServerStrategy( 123 cluster_resolver=..., 124 variable_partitioner=...) 125 coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator( 126 strategy=strategy) 127 128 # Prepare a distribute dataset that will place datasets on the workers. 129 distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...) 130 131 with strategy.scope(): 132 model = ... 133 optimizer, metrics = ... # Keras optimizer/metrics are great choices 134 checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) 135 checkpoint_manager = tf.train.CheckpointManager( 136 checkpoint, checkpoint_dir, max_to_keep=2) 137 # `load_checkpoint` infers initial epoch from `optimizer.iterations`. 138 initial_epoch = load_checkpoint(checkpoint_manager) or 0 139 140 @tf.function 141 def worker_fn(iterator): 142 143 def replica_fn(inputs): 144 batch_data, labels = inputs 145 # calculate gradient, applying gradient, metrics update etc. 146 147 strategy.run(replica_fn, args=(next(iterator),)) 148 149 for epoch in range(initial_epoch, num_epoch): 150 distributed_iterator = iter(distributed_dataset) # Reset iterator state. 151 for step in range(steps_per_epoch): 152 153 # Asynchronously schedule the `worker_fn` to be executed on an arbitrary 154 # worker. This call returns immediately. 155 coordinator.schedule(worker_fn, args=(distributed_iterator,)) 156 157 # `join` blocks until all scheduled `worker_fn`s finish execution. Once it 158 # returns, we can read the metrics and save checkpoints as needed. 159 coordinator.join() 160 logging.info('Metric result: %r', metrics.result()) 161 train_accuracy.reset_states() 162 checkpoint_manager.save() 163 ``` 164 165 With `Model.fit`, 166 167 ```python 168 # Prepare a strategy to use with the cluster and variable partitioning info. 169 strategy = tf.distribute.experimental.ParameterServerStrategy( 170 cluster_resolver=..., 171 variable_partitioner=...) 172 173 # A dataset function takes a `input_context` and returns a `Dataset` 174 def dataset_fn(input_context): 175 dataset = tf.data.Dataset.from_tensors(...) 176 return dataset.repeat().shard(...).batch(...).prefetch(...) 177 178 # With `Model.fit`, a `DatasetCreator` needs to be used. 179 input = tf.keras.utils.experimental.DatasetCreator(dataset_fn=...) 180 181 with strategy.scope(): 182 model = ... # Make sure the `Model` is created within scope. 183 model.compile(optimizer="rmsprop", loss="mse", steps_per_execution=..., ...) 184 185 # Optional callbacks to checkpoint the model, back up the progress, etc. 186 callbacks = [tf.keras.callbacks.ModelCheckpoint(...), ...] 187 188 # `steps_per_epoch` is required with `ParameterServerStrategy`. 189 model.fit(input, epochs=..., steps_per_epoch=..., callbacks=callbacks) 190 ``` 191 192 __Example code for worker and parameter servers__ 193 194 In addition to the coordinator, there should be tasks designated as 195 "worker" or "ps". They should run the following code to start a TensorFlow 196 server, waiting for coordinator's requests: 197 198 ```python 199 # Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves 200 # the cluster information. See below "Cluster setup" section. 201 cluster_resolver = ... 202 203 server = tf.distribute.Server( 204 cluster_resolver.cluster_spec(), 205 job_name=cluster_resolver.task_type, 206 task_index=cluster_resolver.task_id, 207 protocol="grpc") 208 209 # Blocking the process that starts a server from exiting. 210 server.join() 211 ``` 212 213 __Cluster setup__ 214 215 In order for the tasks in the cluster to know other tasks' addresses, 216 a `tf.distribute.cluster_resolver.ClusterResolver` is required to be used 217 in coordinator, worker, and ps. The 218 `tf.distribute.cluster_resolver.ClusterResolver` is responsible for providing 219 the cluster information, as well as the task type and id of the current task. 220 See `tf.distribute.cluster_resolver.ClusterResolver` for more information. 221 222 If `TF_CONFIG` environment variable is set, a 223 `tf.distribute.cluster_resolver.TFConfigClusterResolver` should be used as 224 well. 225 226 Since there are assumptions in 227 `tf.distribute.experimental.ParameterServerStrategy` around the naming of the 228 task types, "chief", "ps", and "worker" should be used in the 229 `tf.distribute.cluster_resolver.ClusterResolver` to refer to the coordinator, 230 parameter servers, and workers, respectively. 231 232 The following example demonstrates setting `TF_CONFIG` for the task designated 233 as a parameter server (task type "ps") and index 1 (the second task), in a 234 cluster with 1 chief, 2 parameter servers, and 3 workers. Note that it needs 235 to be set before the use of 236 `tf.distribute.cluster_resolver.TFConfigClusterResolver`. 237 238 Example code for cluster setup: 239 ```python 240 os.environ['TF_CONFIG'] = ''' 241 { 242 "cluster": { 243 "chief": ["chief.example.com:2222"], 244 "ps": ["ps0.example.com:2222", "ps1.example.com:2222"], 245 "worker": ["worker0.example.com:2222", "worker1.example.com:2222", 246 "worker2.example.com:2222"] 247 }, 248 "task": { 249 "type": "ps", 250 "index": 1 251 } 252 } 253 ''' 254 ``` 255 256 If you prefer to run the same binary for all tasks, you will need to let the 257 binary branch into different roles at the beginning of the program: 258 ```python 259 # If coordinator, create a strategy and start the training program. 260 if cluster_resolver.task_type == 'chief': 261 strategy = tf.distribute.experimental.ParameterServerStrategy( 262 cluster_resolver) 263 ... 264 265 # If worker/ps, create a server 266 elif cluster_resolver.task_type in ("worker", "ps"): 267 server = tf.distribute.Server(...) 268 ... 269 ``` 270 Alternatively, you can also start a bunch of TensorFlow servers in advance and 271 connect to them later. The coordinator can be in the same cluster or on any 272 machine that has connectivity to workers and parameter servers. This is 273 covered in our guide and tutorial. 274 275 __Variable creation with `strategy.scope()`__ 276 277 `tf.distribute.experimental.ParameterServerStrategy` follows the 278 `tf.distribute` API contract where variable creation is expected to be inside 279 the context manager returned by `strategy.scope()`, in order to be correctly 280 placed on parameter servers in a round-robin manner: 281 282 ```python 283 # In this example, we're assuming having 3 ps. 284 strategy = tf.distribute.experimental.ParameterServerStrategy( 285 cluster_resolver=...) 286 coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator( 287 strategy=strategy) 288 289 # Variables should be created inside scope to be placed on parameter servers. 290 # If created outside scope such as `v1` here, it would be placed on the 291 # coordinator. 292 v1 = tf.Variable(initial_value=0.0) 293 294 with strategy.scope(): 295 v2 = tf.Variable(initial_value=1.0) 296 v3 = tf.Variable(initial_value=2.0) 297 v4 = tf.Variable(initial_value=3.0) 298 v5 = tf.Variable(initial_value=4.0) 299 300 # v2 through v5 are created in scope and are distributed on parameter servers. 301 # Default placement is round-robin but the order should not be relied on. 302 assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0" 303 assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0" 304 assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0" 305 assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0" 306 ``` 307 308 See `distribute.Strategy.scope` for more information. 309 310 __Variable partitioning__ 311 312 Having dedicated servers to store variables means being able to divide up, or 313 "shard" the variables across the ps. Partitioning large variable among ps is a 314 commonly used technique to boost training throughput and mitigate memory 315 constraints. It enables parallel computations and updates on different shards 316 of a variable, and often yields better load balancing across parameter 317 servers. Without sharding, models with large variables (e.g, embeddings) that 318 can't fit into one machine's memory would otherwise be unable to train. 319 320 With `tf.distribute.experimental.ParameterServerStrategy`, if a 321 `variable_partitioner` is provided to `__init__` and certain conditions are 322 satisfied, the resulting variables created in scope are sharded across the 323 parameter servers, in a round-robin fashion. The variable reference returned 324 from `tf.Variable` becomes a type that serves as the container of the sharded 325 variables. One can access `variables` attribute of this container for the 326 actual variable components. If building model with `tf.Module` or Keras, 327 the variable components are collected in the `variables` alike attributes. 328 329 It is recommended to use size-based partitioners like 330 `tf.distribute.experimental.partitioners.MinSizePartitioner` to avoid 331 partitioning small variables, which could have negative impact on model 332 training speed. 333 334 ```python 335 # Partition the embedding layer into 2 shards. 336 variable_partitioner = ( 337 tf.distribute.experimental.partitioners.MinSizePartitioner( 338 min_shard_bytes=(256 << 10), 339 max_shards = 2)) 340 strategy = tf.distribute.experimental.ParameterServerStrategy( 341 cluster_resolver=..., 342 variable_partitioner = variable_partitioner) 343 with strategy.scope(): 344 embedding = tf.keras.layers.Embedding(input_dim=1024, output_dim=1024) 345 assert len(embedding.variables) == 2 346 assert isinstance(embedding.variables[0], tf.Variable) 347 assert isinstance(embedding.variables[1], tf.Variable) 348 assert embedding.variables[0].shape == (512, 1024) 349 assert embedding.variables[1].shape == (512, 1024) 350 ``` 351 352 The sharded variable container can be converted to a `Tensor` via 353 `tf.convert_to_tensor`. This means the container can be directly used in most 354 Python Ops where such `Tensor` conversion automatically happens. For example, 355 in the above code snippet, `x * self.w` would implicitly apply the said tensor 356 conversion. Note that such conversion can be expensive, as the variable 357 components need to be transferred from multiple parameter servers to where 358 the value is used. 359 360 `tf.nn.embedding_lookup` on the other hand doesn't apply the tensor 361 conversion, and performs parallel lookups on the variable components instead. 362 This is crucial to scale up embedding lookups when the embedding table 363 variable is large. 364 365 When a partitioned variable is saved to a `SavedModel`, it will be saved as if 366 it is one single variable. This improves serving efficiency by eliminating 367 a number of Ops that handle the partiton aspects. 368 369 Known limitations of variable partitioning: 370 371 * Number of partitions must not change across Checkpoint saving/loading. 372 373 * After saving partitioned variables to a SavedModel, the SavedModel can't be 374 loaded via `tf.saved_model.load`. 375 376 * Partition variable doesn't directly work with `tf.GradientTape`, please use 377 the `variables` attributes to get the actual variable components and use 378 them in gradient APIs instead. 379 380 __Dataset preparation__ 381 382 With `tf.distribute.experimental.ParameterServerStrategy`, a dataset is 383 created in each of the workers to be used for training. This is done by 384 creating a `dataset_fn` that takes no argument and returns a 385 `tf.data.Dataset`, and passing the `dataset_fn` into 386 `tf.distribute.experimental.coordinator. 387 ClusterCoordinator.create_per_worker_dataset`. We recommend the dataset to be 388 shuffled and repeated to have the examples run through the training as evenly 389 as possible. 390 391 ```python 392 def dataset_fn(): 393 filenames = ... 394 dataset = tf.data.Dataset.from_tensor_slices(filenames) 395 396 # Dataset is recommended to be shuffled, and repeated. 397 return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...) 398 399 coordinator = 400 tf.distribute.experimental.coordinator.ClusterCoordinator(strategy=...) 401 distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn) 402 ``` 403 404 __Limitations__ 405 406 * `tf.distribute.experimental.ParameterServerStrategy` in TF2 is experimental, 407 and the API is subject to further changes. 408 409 * When using `Model.fit`, `tf.distribute.experimental.ParameterServerStrategy` 410 must be used with a `tf.keras.utils.experimental.DatasetCreator`, and 411 `steps_per_epoch` must be specified. 412 """ 413 414 # pyformat: disable 415 def __init__(self, cluster_resolver, variable_partitioner=None): 416 """Initializes the TF2 parameter server strategy. 417 418 This initializes the `tf.distribute.experimental.ParameterServerStrategy` 419 object to be ready for use with 420 `tf.distribute.experimental.coordinator.ClusterCoordinator`. 421 422 Args: 423 cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver` 424 object. 425 variable_partitioner: 426 a `distribute.experimental.partitioners.Partitioner` that specifies 427 how to partition variables. If `None`, variables will not be 428 partitioned. 429 430 * Predefined partitioners in `tf.distribute.experimental.partitioners` 431 can be used for this argument. A commonly used partitioner is 432 `MinSizePartitioner(min_shard_bytes = 256 << 10, max_shards = num_ps)`, 433 which allocates at least 256K per shard, and each ps gets at most one 434 shard. 435 436 * `variable_partitioner` will be called for each variable created under 437 strategy `scope` to instruct how the variable should be partitioned. 438 Variables that have only one partition along the partitioning axis 439 (i.e., no need for partition) will be created as a normal `tf.Variable`. 440 441 * Only the first / outermost axis partitioning is supported. 442 443 * Div partition strategy is used to partition variables. Assuming we 444 assign consecutive integer ids along the first axis of a variable, then 445 ids are assigned to shards in a contiguous manner, while attempting to 446 keep each shard size identical. If the ids do not evenly divide the 447 number of shards, each of the first several shards will be assigned one 448 more id. For instance, a variable whose first dimension is 13 has 13 449 ids, and they are split across 5 shards as: 450 `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`. 451 452 * Variables created under `strategy.extended.colocate_vars_with` will 453 not be partitioned. 454 """ 455 # pyformat: enable 456 self._cluster_resolver = cluster_resolver 457 458 self._verify_args_and_config(cluster_resolver) 459 self._cluster_coordinator = None 460 logging.info( 461 "`tf.distribute.experimental.ParameterServerStrategy` is initialized " 462 "with cluster_spec: %s", cluster_resolver.cluster_spec()) 463 464 # TODO(b/167894802): Make coordinator, worker, and ps names customizable. 465 self._connect_to_cluster(coordinator_name="chief") 466 self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver, 467 variable_partitioner) 468 super(ParameterServerStrategyV2, self).__init__(self._extended) 469 distribute_lib.distribution_strategy_gauge.get_cell("V2").set( 470 "ParameterServerStrategy") 471 self._should_use_with_coordinator = True 472 # Used while constructing distributed iterators. 473 self._canonicalize_devices = False 474 475 def _connect_to_cluster(self, coordinator_name): 476 if coordinator_name in ["worker", "ps"]: 477 raise ValueError("coordinator name should not be 'worker' or 'ps'.") 478 cluster_spec = self._cluster_resolver.cluster_spec() 479 self._num_workers = len(cluster_spec.as_dict().get("worker", ())) 480 self._num_ps = len(cluster_spec.as_dict().get("ps", ())) 481 482 device_filters = server_lib.ClusterDeviceFilters() 483 # For any worker, only the devices on ps and coordinator nodes are visible 484 for i in range(self._num_workers): 485 device_filters.set_device_filters( 486 "worker", i, ["/job:ps", "/job:%s" % coordinator_name]) 487 # Similarly for any ps, only the devices on workers and coordinator are 488 # visible 489 for i in range(self._num_ps): 490 device_filters.set_device_filters( 491 "ps", i, ["/job:worker", "/job:%s" % coordinator_name]) 492 493 # Allow at most one outstanding RPC for each worker at a certain time. This 494 # is to simplify worker failure handling in the runtime 495 os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False" 496 497 logging.info("%s is now connecting to cluster with cluster_spec: %r", 498 self.__class__.__name__, cluster_spec) 499 remote.connect_to_cluster( 500 cluster_spec, 501 job_name=coordinator_name, 502 protocol=self._cluster_resolver.rpc_layer, 503 cluster_device_filters=device_filters) 504 505 distribute_lib.distribution_strategy_replica_gauge.get_cell( 506 "ps_strategy_num_workers").set(self._num_workers) 507 distribute_lib.distribution_strategy_replica_gauge.get_cell( 508 "ps_strategy_num_ps").set(self._num_ps) 509 510 def _verify_args_and_config(self, cluster_resolver): 511 if not cluster_resolver.cluster_spec(): 512 raise ValueError("Cluster spec must be non-empty in " 513 "`tf.distribute.cluster_resolver.ClusterResolver`.") 514 cluster_spec = cluster_resolver.cluster_spec() 515 516 # The following checks if the task types are allowed (chief, ps, worker). 517 multi_worker_util._validate_cluster_spec( # pylint: disable=protected-access 518 cluster_spec, 519 cluster_resolver.task_type, 520 cluster_resolver.task_id) 521 522 if multi_worker_util.task_count(cluster_spec, "ps") < 1: 523 raise ValueError("There must be at least one ps.") 524 525 if multi_worker_util.task_count(cluster_spec, "worker") < 1: 526 raise ValueError("There must be at least one worker.") 527 528 529class ParameterServerStrategyV2Extended( 530 parameter_server_strategy.ParameterServerStrategyExtended): 531 """Extended class for ParameterServerStrategyV2. 532 533 Please see `tf.distribute.StrategyExtended` doc for more information. 534 """ 535 536 def __init__(self, container_strategy, cluster_resolver, 537 variable_partitioner): 538 """Initialization of ParameterServerStrategyV2Extended.""" 539 super(ParameterServerStrategyV2Extended, self).__init__(container_strategy) 540 self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", [])) 541 self._num_workers = len(cluster_resolver.cluster_spec().as_dict().get( 542 "worker", [])) 543 self._variable_count = 0 544 545 self._variable_partitioner = variable_partitioner 546 # The following two attrs are to verify that `ParameterServerStrategy` 547 # methods are properly used with a `ClusterCoordinator`. 548 self._used_with_coordinator = False 549 self._being_scheduled = False 550 self._set_num_gpus() 551 distribute_lib.distribution_strategy_replica_gauge.get_cell( 552 "num_gpus_per_worker").set(self._num_gpus_per_worker) 553 554 # Don't canonicalize the devices here since this code is executed on Chief, 555 # but we want the reduce evaluation to be done on each worker. Placer will 556 # automatically choose the right device based on current context. 557 # TODO(ishark): Use select_cross_device_ops instead. 558 self._cross_device_ops = cross_device_ops_lib.ReductionToOneDevice( 559 reduce_to_device="/device:CPU:0") 560 self._cross_device_ops._canonicalize_devices = False # pylint: disable=protected-access 561 self._allow_run_without_coordinator = False 562 563 def _set_num_gpus(self): 564 devices = config.list_logical_devices("GPU") 565 per_worker_gpus = {} 566 for d in devices: 567 d_spec = tf_device.DeviceSpec.from_string(d.name) 568 if d_spec.device_type == "GPU" and d_spec.job == "worker": 569 # TODO(b/167894802): update if worker name is customizable 570 job_spec = d_spec.replace(device_type=None, device_index=None) 571 per_worker_gpus[job_spec] = per_worker_gpus.get(job_spec, 0) + 1 572 573 num_gpus = 0 574 for _, count in per_worker_gpus.items(): 575 if num_gpus > 0 and count != num_gpus: 576 raise ValueError("Mismatched number of GPUs per worker") 577 num_gpus = count 578 579 self._num_gpus_per_worker = num_gpus 580 logging.info(f"Number of GPUs on workers: {self._num_gpus_per_worker}") 581 582 @property 583 def _num_replicas_in_sync(self): 584 return self._num_gpus_per_worker or 1 585 586 def _create_var_creator(self, next_creator, **kwargs): 587 aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) 588 589 def var_creator(**kwargs): 590 """Create an AggregatingVariable.""" 591 # Create and wrap the variable. 592 v = next_creator(**kwargs) 593 wrapped_v = ps_values.CachingVariable(v) 594 wrapped = ps_values.AggregatingVariable(self._container_strategy(), 595 wrapped_v, aggregation) 596 return wrapped 597 598 if self._num_replicas_in_sync > 1: 599 if aggregation not in ( 600 vs.VariableAggregation.NONE, 601 vs.VariableAggregation.SUM, 602 vs.VariableAggregation.MEAN, 603 vs.VariableAggregation.ONLY_FIRST_REPLICA 604 ): 605 raise ValueError("Invalid variable aggregation mode: " + aggregation + 606 " for variable: " + kwargs["name"]) 607 return var_creator 608 else: 609 def variable_creator_single_replica(**kwargs): 610 v = next_creator(**kwargs) 611 return ps_values.CachingVariable(v) 612 return variable_creator_single_replica 613 614 def _create_variable(self, next_creator, **kwargs): 615 """Implements StrategyExtendedV2._create_variable. 616 617 Creates a `Variable` or a `ShardedVariable`. A `ShardedVariable` will be 618 created if satisfying all the following criteria: 619 1. `self._variable_partitioner` results in more than one partition on the 620 first axis. 621 2. variable's rank is greater than 0. 622 3. variable is not colocated with another variable. 623 Otherwise a `Variable` will be created. 624 625 Args: 626 next_creator: See `variable_scope.variable_creator_scope`; the next 627 creator in the chain. 628 **kwargs: Passed through to the next creator. 629 630 Returns: 631 A `Variable` or `ShardedVariable`. 632 """ 633 634 var_creator = self._create_var_creator(next_creator, **kwargs) 635 if "colocate_with" in kwargs: # Never partition colocated_with variables. 636 colocate_with = kwargs["colocate_with"] 637 # Clear the variable scope to avoid possible conflicts between device 638 # scope and colocation scope. 639 with ops.device(None): 640 with ops.colocate_with(colocate_with): 641 var = var_creator(**kwargs) 642 logging.debug( 643 "Creating variable (name:%s, shape:%r) that colocates with %s", 644 var.name, var.shape, kwargs["colocate_with"].name) 645 return var 646 647 if self._variable_partitioner is None: 648 return self._create_variable_round_robin(var_creator, **kwargs) 649 650 name = kwargs.get("name", None) 651 initial_value = kwargs.get("initial_value", None) 652 if initial_value is None: 653 raise ValueError( 654 "It looks like you are using `ParameterServerStrategy` with a " 655 "`variable_partitioner`, and trying to create a variable without " 656 "specifying `initial_value`. This is not allowed. Please specify the " 657 "`initial_value`. This can also happen if you are trying to load a " 658 "saved_model within a `ParameterServerStrategy` scope. Loading a " 659 "saved_model with `variable_partitioner` is not supported.") 660 661 # Two cases where initial_value can be a callable: 662 # 1. initial_value is passed as a callable, e.g, an `initializer` class. 663 # 2. restoring from checkpoint, initial_value is a 664 # "CheckpointInitialValueCallable". 665 init_from_fn = callable(initial_value) 666 667 dtype = kwargs.get("dtype", None) 668 shape = kwargs.get("shape", None) 669 if init_from_fn and (shape is None or dtype is None): 670 init_from_fn = False 671 initial_value = initial_value() 672 if not init_from_fn: 673 # The initial_value is created on coordinator, it will need to be sent to 674 # ps for variable initialization, which can be inefficient and can 675 # potentially hit the 2GB limit on protobuf serialization. 676 initial_value = ops.convert_to_tensor(initial_value, dtype=dtype) 677 dtype = initial_value.dtype 678 shape = initial_value.shape 679 else: 680 shape = tensor_shape.as_shape(shape) 681 682 if shape.rank == 0: # Skip partitioning rank-0 variable. 683 return self._create_variable_round_robin(var_creator, **kwargs) 684 685 num_partitions = self._variable_partitioner(shape=shape, dtype=dtype) 686 if not num_partitions or num_partitions[0] == 0 or any( 687 v != 1 for v in num_partitions[1:]): 688 raise ValueError( 689 "variable_partitioner must return a list/tuple whose elements are 1" 690 " besides the first element (non-zero), got: %r" % num_partitions) 691 692 if num_partitions[0] == 1: # no partition 693 return self._create_variable_round_robin(var_creator, **kwargs) 694 695 # Use "div" partition strategy to partition the variable. 696 num_partitions = min(num_partitions[0], shape[0]) 697 base = shape[0] // num_partitions 698 extra = shape[0] % num_partitions 699 # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2] 700 # offsets: [0, 3, 6, 8, 10] 701 offsets = [] 702 for i in range(num_partitions): 703 if i == 0: 704 offsets.append(0) 705 else: 706 prev_shard_size = base + (1 if i - 1 < extra else 0) 707 offsets.append(offsets[i - 1] + prev_shard_size) 708 offsets.append(shape[0]) 709 710 def init_shard_fn(shard_index): 711 if not init_from_fn: 712 logging.log_if( 713 logging.WARN, _INEFFICIENT_INIT_WARNING % name, shard_index == 0 and 714 shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS) 715 return initial_value[offsets[shard_index]:offsets[shard_index + 1]] 716 partition_shape = (offsets[shard_index + 1] - 717 offsets[shard_index],) + shape[1:] 718 partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:]) 719 arg_spec = tf_inspect.getfullargspec(initial_value) 720 if ("shard_info" not in arg_spec.args and 721 "shard_info" not in arg_spec.kwonlyargs): 722 try: 723 value = initial_value( 724 partition_shape=partition_shape, 725 partition_offset=partition_offset) 726 except (TypeError, ValueError): 727 # TypeError: Initializer doesn't accept kwargs 728 # ValueError: Initializer doesn't accept partition kwargs 729 # In both cases we go ahead creating the full value and then slice. 730 value = initial_value() 731 732 if value.shape == partition_shape: 733 # Initializer supports partition: value is the partition value. 734 return value 735 else: 736 # Initializer doesn't support partition: value is the full value 737 # and needs to be sliced to get the partition value. 738 logging.log_if( 739 logging.WARN, _INEFFICIENT_INIT_WARNING % name, 740 shard_index == 0 and 741 shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS) 742 return value[offsets[shard_index]:offsets[shard_index + 1]] 743 else: 744 # For compatibility with `CheckpointInitialValueCallable`. 745 return initial_value( 746 shard_info=trackable.ShardInfo( 747 shape=tensor_shape.as_shape(partition_shape), 748 offset=partition_offset)) 749 750 var_list = [] 751 for i in range(num_partitions): 752 kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:] 753 kwargs["initial_value"] = lambda: init_shard_fn(i) 754 if name is not None: 755 kwargs["name"] = "{}/part_{}".format(name, i) 756 var_list.append(self._create_variable_round_robin(var_creator, **kwargs)) 757 758 result = sharded_variable.ShardedVariable(var_list) 759 return result 760 761 def _create_variable_round_robin(self, next_creator, **kwargs): 762 # Clear the colocation scope to avoid possible conflicts between device 763 # scope and colocation scope. 764 with ops.colocate_with(None, ignore_existing=True): 765 # Explicitly set CPU:0 device for PS in case create variable is called 766 # inside replica_fn and worker has with GPU:0 scope. 767 with ops.device("/job:ps/task:%d/device:CPU:0" % 768 (self._variable_count % self._num_ps)): 769 var = next_creator(**kwargs) 770 logging.debug( 771 "Creating variable (name:%s, shape:%r) on " 772 "/job:ps/task:%d/device:CPU:0", 773 var.name, var.shape, (self._variable_count % self._num_ps)) 774 self._variable_count += 1 775 return var 776 777 def _assert_used_with_cluster_coordinator(self): 778 if (not self._used_with_coordinator and 779 not self._allow_run_without_coordinator): 780 raise NotImplementedError( 781 "`tf.distribute.experimental.ParameterServerStrategy` must be used " 782 "with `tf.distribute.experimental.coordinator.ClusterCoordinator` in " 783 "a custom training loop. If you are using `Model.fit`, please supply " 784 "a dataset function directly to a " 785 "`tf.keras.utils.experimental.DatasetCreator` instead.") 786 787 def _assert_being_scheduled_by_cluster_coordinator(self): 788 if not self._being_scheduled and not self._allow_run_without_coordinator: 789 logging.warning( 790 "It is detected that a function used with " 791 "`tf.distribute.experimental.ParameterServerStrategy` " 792 "is executed locally on the coordinator. This is inefficient but may " 793 "be valid for one-off tasks such as inferring output signature. " 794 "To properly distribute functions to run on workers, `run` or " 795 "`reduce` should be used within a function passed to `" 796 "tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`." 797 ) 798 799 # options is not used right now. But we may want to support options while 800 # creating InputWorkers in future, similar to MirroredStrategy. 801 def _input_workers_with_options(self, options=None): 802 input_workers_devices = ( 803 ("/device:CPU:0", self.worker_devices),) 804 return input_lib.InputWorkers( 805 input_workers_devices, canonicalize_devices=False) 806 807 def _experimental_distribute_dataset(self, dataset, options): 808 input_workers_devices = self._input_workers_with_options() 809 810 # If this DistributedDataset is created outside ClusterCoordinator, i,e, 811 # outside a tf.function, we don't build its underlying datasets immediately 812 # until it is passed to ClusterCoordinator.create_per_worker_dataset. 813 return input_lib.get_distributed_dataset( 814 dataset, 815 input_workers_devices, 816 self._container_strategy(), 817 num_replicas_in_sync=self._num_replicas_in_sync, 818 options=options, 819 build=ops.inside_function()) # will be built by ClusterCoordinator 820 821 def _distribute_datasets_from_function(self, dataset_fn, options): 822 # There is no synchronization beyond a worker and thus, the number of 823 # input pipelines in sync is only 1 per worker. 824 input_pipeline_id_in_sync = 0 825 num_input_pipelines_in_sync = 1 826 827 input_context = distribute_lib.InputContext( 828 num_input_pipelines=num_input_pipelines_in_sync, 829 input_pipeline_id=input_pipeline_id_in_sync, 830 num_replicas_in_sync=self._num_replicas_in_sync) 831 832 # If this DistributedDatasetFromFunction is created outside 833 # ClusterCoordinator, i,e, outside a tf.function, we don't build its 834 # underlying datasets immediately until it is passed to 835 # ClusterCoordinator.create_per_worker_dataset. 836 return input_lib.get_distributed_datasets_from_function( 837 dataset_fn, 838 self._input_workers_with_options(options), 839 [input_context], 840 self._container_strategy(), 841 options=options, 842 build=ops.inside_function()) # will be built by ClusterCoordinator 843 844 @property 845 def worker_devices(self): 846 num_gpus = self._num_gpus_per_worker 847 if num_gpus > 0: 848 compute_devices = tuple("/device:GPU:%d" % (i,) for i in range(num_gpus)) 849 else: 850 compute_devices = ("/device:CPU:0",) 851 return compute_devices 852 853 def _call_for_each_replica(self, fn, args, kwargs): 854 self._assert_being_scheduled_by_cluster_coordinator() 855 856 return mirrored_run.call_for_each_replica(self._container_strategy(), fn, 857 args, kwargs) 858 859 def _reduce(self, reduce_op, value): 860 self._assert_being_scheduled_by_cluster_coordinator() 861 dst = device_util.current() or self._default_device or "/device:CPU:0" 862 destinations = device_util.canonicalize_without_job_and_task(dst) 863 result = self._local_results( 864 self.reduce_to(reduce_op, value, destinations))[0] 865 return result 866 867 def _reduce_to(self, reduce_op, value, destinations, options): 868 self._assert_being_scheduled_by_cluster_coordinator() 869 870 def get_values(x): 871 if isinstance(x, values.DistributedValues): 872 return self._cross_device_ops.reduce( 873 reduce_op, x, destinations=destinations) # pylint: disable=protected-access 874 return x 875 876 return nest.map_structure(get_values, value) 877 878 879# The warning that will be logged if the way we initialize sharded variables 880# is memory-inefficient. 881_INEFFICIENT_INIT_WARNING = ( 882 "Large variable %s is partitioned but not initialized in a " 883 "memory-efficient way. On each shard, the full value is first being " 884 "created and then sliced into smaller values. To reduce the memory " 885 "footprint, explicitly specify `dtype` and `shape` when creating " 886 "variables, and use `tf.initializers` to initialize the variable. " 887 "Note that some initializers (e.g., orthogonal) don't support " 888 "memory-efficient initialization and there is not much you can do here.") 889 890_LARGE_VARIABLE_NUM_ELEMENTS = 1e9 891