# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parameter server strategy V2 class.

This is currently under development and the API is subject to change.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.eager import remote
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export

ALLOWED_TASK_TYPES = ("chief", "worker", "ps")


@tf_export("distribute.experimental.ParameterServerStrategy", v1=[])
class ParameterServerStrategyV2(distribute_lib.Strategy):
  """A multi-worker tf.distribute strategy with parameter servers.

  Parameter server training is a common data-parallel method to scale up a
  machine learning model on multiple machines. A parameter server training
  cluster consists of workers and parameter servers. Variables are created on
  parameter servers and they are read and updated by workers in each step.
  By default, workers read and update these variables independently without
  synchronizing with each other. This configuration is known as asynchronous
  training.

  In TensorFlow 2, we recommend an architecture based on central coordination
  for parameter server training. Each worker and parameter server runs a
  `tf.distribute.Server`, and on top of that, a coordinator task is responsible
  for creating resources on workers and parameter servers, dispatching
  functions, and coordinating the training. The coordinator uses a
  `tf.distribute.experimental.coordinator.ClusterCoordinator` to coordinate the
  cluster, and a `tf.distribute.experimental.ParameterServerStrategy` to define
  variables on parameter servers and computation on workers.

  For the training to work, the coordinator dispatches `tf.function`s to be
  executed on remote workers. Upon receiving requests from the coordinator, a
  worker executes the `tf.function` by reading the variables from parameter
  servers, executing the ops, and updating the variables on the parameter
  servers.
  Each worker only processes requests from the coordinator and communicates
  with parameter servers, without direct interactions with other workers in
  the cluster.

  As a result, failures of some workers do not prevent the cluster from
  continuing the work, and this allows the cluster to train with instances
  that can be occasionally unavailable (e.g. preemptible or spot instances).
  The coordinator and parameter servers, though, must be available at all
  times for the cluster to make progress.

  Note that the coordinator is not one of the training workers. Instead, it
  creates resources such as variables and datasets, dispatches `tf.function`s,
  saves checkpoints and so on. In addition to workers, parameter servers and
  the coordinator, an optional evaluator can be run on the side that
  periodically reads the checkpoints saved by the coordinator and runs
  evaluations against each checkpoint.

  `tf.distribute.experimental.ParameterServerStrategy` has to work in
  conjunction with a
  `tf.distribute.experimental.coordinator.ClusterCoordinator` object.
  Standalone usage of `tf.distribute.experimental.ParameterServerStrategy`
  without central coordination is not supported at this time.

  __Example code for coordinator__

  Here's an example usage of the API, with a custom training loop to train a
  model. This code snippet is intended to be run on (the only) one task that
  is designated as the coordinator. Note that the `cluster_resolver`,
  `variable_partitioner`, and `dataset_fn` arguments are explained in the
  following "Cluster setup", "Variable partitioning", and "Dataset preparation"
  sections.

  ```python
  # Set the environment variable to allow reporting worker and ps failure to
  # the coordinator. This is a short-term workaround.
  os.environ["GRPC_FAIL_FAST"] = "use_caller"

  # Prepare a strategy to use with the cluster and variable partitioning info.
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...,
      variable_partitioner=...)
  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
      strategy=strategy)

  # Prepare a distributed dataset that will place datasets on the workers.
  distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...)

  with strategy.scope():
    model = ...
    optimizer, metrics = ...  # Keras optimizer/metrics are great choices
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, checkpoint_dir, max_to_keep=2)
    # `load_checkpoint` infers initial epoch from `optimizer.iterations`.
    initial_epoch = load_checkpoint(checkpoint_manager) or 0

  @tf.function
  def worker_fn(iterator):

    def replica_fn(inputs):
      batch_data, labels = inputs
      # Calculate gradients, apply gradients, update metrics, etc.

    strategy.run(replica_fn, args=(next(iterator),))

  for epoch in range(initial_epoch, num_epoch):
    distributed_iterator = iter(distributed_dataset)  # Reset iterator state.
    for step in range(steps_per_epoch):

      # Asynchronously schedule the `worker_fn` to be executed on an arbitrary
      # worker. This call returns immediately.
      coordinator.schedule(worker_fn, args=(distributed_iterator,))

    # `join` blocks until all scheduled `worker_fn`s finish execution. Once it
    # returns, we can read the metrics and save checkpoints as needed.
    coordinator.join()
    logging.info('Metric result: %r', metrics.result())
    metrics.reset_states()
    checkpoint_manager.save()
  ```
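
  The `load_checkpoint` helper above is not part of the TensorFlow API; it is
  a user-defined function. A minimal sketch of such a helper, assuming one
  checkpoint is saved per epoch and that the `optimizer` and `steps_per_epoch`
  names from the surrounding snippet are in scope, could look like:

  ```python
  def load_checkpoint(checkpoint_manager):
    # Restore the most recent checkpoint, if any. Returns None when training
    # starts from scratch.
    latest = checkpoint_manager.latest_checkpoint
    if latest is None:
      return None
    checkpoint_manager.checkpoint.restore(latest)
    # `optimizer.iterations` counts applied gradient steps, so dividing by the
    # (assumed constant) number of steps per epoch recovers the epoch index.
    return optimizer.iterations.numpy() // steps_per_epoch
  ```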

  __Example code for worker and parameter servers__

  In addition to the coordinator, there should be tasks designated as
  "worker" or "ps". They should run the following code to start a TensorFlow
  server and wait for the coordinator's requests:

  ```python
  # Set the environment variable to allow reporting worker and ps failure to
  # the coordinator.
  os.environ["GRPC_FAIL_FAST"] = "use_caller"

  # Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves
  # the cluster information. See the "Cluster setup" section below.
  cluster_resolver = ...

  server = tf.distribute.Server(
      cluster_resolver.cluster_spec(),
      job_name=cluster_resolver.task_type,
      task_index=cluster_resolver.task_id,
      protocol="grpc")

  # Block the process that starts a server from exiting.
  server.join()
  ```

  __Cluster setup__

  In order for the tasks in the cluster to know each other's addresses,
  a `tf.distribute.cluster_resolver.ClusterResolver` is required on the
  coordinator, workers, and ps. The
  `tf.distribute.cluster_resolver.ClusterResolver` is responsible for providing
  the cluster information, as well as the task type and id of the current task.
  See `tf.distribute.cluster_resolver.ClusterResolver` for more information.

  If the `TF_CONFIG` environment variable is set, a
  `tf.distribute.cluster_resolver.TFConfigClusterResolver` should be used as
  well.

  Since there are assumptions in
  `tf.distribute.experimental.ParameterServerStrategy` around the naming of the
  task types, "chief", "ps", and "worker" should be used in the
  `tf.distribute.cluster_resolver.ClusterResolver` to refer to the coordinator,
  parameter servers, and workers, respectively.

  The following example demonstrates setting `TF_CONFIG` for the task
  designated as a parameter server (task type "ps") with index 1 (the second
  task), in a cluster with 1 chief, 2 parameter servers, and 3 workers. Note
  that it needs to be set before the use of
  `tf.distribute.cluster_resolver.TFConfigClusterResolver`.

  Example code for cluster setup:
  ```python
  os.environ['TF_CONFIG'] = '''
  {
    "cluster": {
      "chief": ["chief.example.com:2222"],
      "ps": ["ps0.example.com:2222", "ps1.example.com:2222"],
      "worker": ["worker0.example.com:2222", "worker1.example.com:2222",
                 "worker2.example.com:2222"]
    },
    "task": {
      "type": "ps",
      "index": 1
    }
  }
  '''
  ```

  If you prefer to run the same binary for all tasks, you will need to let the
  binary branch into different roles at the beginning of the program:
  ```python
  os.environ["GRPC_FAIL_FAST"] = "use_caller"
  cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

  # If coordinator, create a strategy and start the training program.
  if cluster_resolver.task_type == 'chief':
    strategy = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver)
    ...

  # If worker/ps, create a server.
  elif cluster_resolver.task_type in ("worker", "ps"):
    server = tf.distribute.Server(...)
    ...
  ```

  Alternatively, you can also start a number of TensorFlow servers in advance
  and connect to them later. The coordinator can be in the same cluster or on
  any machine that has connectivity to the workers and parameter servers. This
  is covered in our guide and tutorial.
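
  As a rough sketch of that approach (the addresses, job sizes, and the use of
  `tf.distribute.cluster_resolver.SimpleClusterResolver` below are assumptions
  made for illustration, not the only option), the servers can be brought up
  with a fixed cluster definition, and the coordinator later builds a resolver
  from the same definition:

  ```python
  cluster_def = tf.train.ClusterSpec({
      "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
      "ps": ["ps0.example.com:2222"],
  })

  # On each worker/ps machine, started ahead of time with its own
  # job_name/task_index, then left running.
  server = tf.distribute.Server(
      cluster_def, job_name="worker", task_index=0, protocol="grpc")
  server.join()

  # Later, on the coordinator, point a resolver at the already-running servers
  # and build the strategy from it.
  cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
      cluster_def, rpc_layer="grpc")
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver)
  ```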

  __Variable creation with `strategy.scope()`__

  `tf.distribute.experimental.ParameterServerStrategy` follows the
  `tf.distribute` API contract where variable creation is expected to be inside
  the context manager returned by `strategy.scope()`, in order to be correctly
  placed on parameter servers in a round-robin manner:

  ```python
  # In this example, we assume the cluster has 3 ps.
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...)
  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
      strategy=strategy)

  # Variables should be created inside scope to be placed on parameter servers.
  # If created outside scope, such as `v1` here, it would be placed on the
  # coordinator.
  v1 = tf.Variable(initial_value=0.0)

  with strategy.scope():
    v2 = tf.Variable(initial_value=1.0)
    v3 = tf.Variable(initial_value=2.0)
    v4 = tf.Variable(initial_value=3.0)
    v5 = tf.Variable(initial_value=4.0)

  # v2 through v5 are created in scope and are distributed on parameter
  # servers. Default placement is round-robin but the order should not be
  # relied on.
  assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0"
  assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0"
  assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0"
  assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0"
  ```

  See `distribute.Strategy.scope` for more information.

  __Variable partitioning__

  Having dedicated servers to store variables means being able to divide up,
  or "shard", the variables across the ps. Partitioning a large variable among
  ps is a commonly used technique to boost training throughput and mitigate
  memory constraints. It enables parallel computations and updates on
  different shards of a variable, and often yields better load balancing
  across parameter servers. Without sharding, models with large variables
  (e.g., embeddings) that can't fit into one machine's memory would otherwise
  be unable to train.

  With `tf.distribute.experimental.ParameterServerStrategy`, if a
  `variable_partitioner` is provided to `__init__` and certain conditions are
  satisfied, the resulting variables created in scope are sharded across the
  parameter servers, in a round-robin fashion. The variable reference returned
  from `tf.Variable` becomes a type that serves as the container of the
  sharded variables. One can access the `variables` attribute of this
  container for the actual variable components. If building a model with
  `tf.Module` or Keras, the variable components are collected in the
  `variables`-like attributes.

  ```python
  class Dense(tf.Module):
    def __init__(self, name=None):
      super().__init__(name=name)
      self.w = tf.Variable(tf.random.normal([100, 10]), name='w')

    def __call__(self, x):
      return x * self.w

  # Partition the dense layer into 2 shards.
  variable_partitioner = (
      tf.distribute.experimental.partitioners.FixedShardsPartitioner(
          num_shards=2))
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...,
      variable_partitioner=variable_partitioner)
  with strategy.scope():
    dense = Dense()
  assert len(dense.variables) == 2
  assert isinstance(dense.variables[0], tf.Variable)
  assert isinstance(dense.variables[1], tf.Variable)
  assert dense.variables[0].shape == (50, 10)
  assert dense.variables[1].shape == (50, 10)
  ```

  The sharded variable container can be converted to a `Tensor` via
  `tf.convert_to_tensor`. This means the container can be directly used in
  most Python ops where such `Tensor` conversion automatically happens. For
  example, in the above code snippet, `x * self.w` would implicitly apply this
  tensor conversion. Note that such conversion can be expensive, as the
  variable components need to be transferred from multiple parameter servers
  to where the value is used.

  `tf.nn.embedding_lookup`, on the other hand, doesn't apply the tensor
  conversion, and performs parallel lookups on the variable components
  instead. This is crucial for scaling up embedding lookups when the embedding
  table variable is large.
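
  For illustration, here is a minimal sketch of the difference, reusing the
  2-shard `strategy` above (the table shape and the ids are assumptions made
  up for this example):

  ```python
  with strategy.scope():
    # A [1000, 16] embedding table, sharded into two [500, 16] components.
    embedding = tf.Variable(tf.random.uniform([1000, 16]))

  ids = tf.constant([3, 7, 42])

  # Parallel lookups on the components; the full table is never materialized
  # in one place.
  rows = tf.nn.embedding_lookup(embedding, ids)

  # By contrast, this gathers every component to the caller before use.
  full_table = tf.convert_to_tensor(embedding)
  ```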

  When a partitioned variable is saved to a `SavedModel`, it is saved as if it
  were one single variable. This improves serving efficiency by eliminating a
  number of ops that handle the partition aspects.

  Known limitations of variable partitioning:

  * The number of partitions must not change across checkpoint saving and
    loading.

  * After saving partitioned variables to a SavedModel, the SavedModel can't
    be loaded via `tf.saved_model.load`.

  * Partitioned variables don't directly work with `tf.GradientTape`; please
    use the `variables` attribute to get the actual variable components and
    use them in gradient APIs instead, as sketched after this list.
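
  A minimal sketch of that pattern, reusing the `Dense` module defined above
  (the input and loss are arbitrary, and in real training this would live
  inside a function scheduled through the coordinator):

  ```python
  with strategy.scope():
    dense = Dense()

  with tf.GradientTape() as tape:
    y = dense(tf.ones([100, 10]))
    loss = tf.reduce_sum(y)

  # Differentiate with respect to the actual variable components rather than
  # the sharded container; one gradient is returned per component.
  grads = tape.gradient(loss, dense.variables)
  ```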

  __Dataset preparation__

  With `tf.distribute.experimental.ParameterServerStrategy`, a dataset is
  created on each of the workers to be used for training. This is done by
  creating a `dataset_fn` that takes no argument and returns a
  `tf.data.Dataset`, and passing the `dataset_fn` into
  `tf.distribute.experimental.coordinator.
  ClusterCoordinator.create_per_worker_dataset`. We recommend the dataset to
  be shuffled and repeated so that the examples run through the training as
  evenly as possible.

  ```python
  def dataset_fn():
    filenames = ...
    dataset = tf.data.Dataset.from_tensor_slices(filenames)

    # Dataset is recommended to be shuffled, and repeated.
    return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...)

  coordinator = (
      tf.distribute.experimental.coordinator.ClusterCoordinator(strategy=...))
  distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
  ```

  __Limitations__

  * `tf.distribute.experimental.ParameterServerStrategy` in TF2 is
    experimental, and the API is subject to further changes.

  * `tf.distribute.experimental.ParameterServerStrategy` does not yet support
    training with GPU(s). This is a feature request being developed.

  * `tf.distribute.experimental.ParameterServerStrategy` only supports
    [custom training loop
    API](https://www.tensorflow.org/tutorials/distribute/custom_training)
    currently in TF2. Usage of it with Keras `compile`/`fit` API is being
    developed.

  * `tf.distribute.experimental.ParameterServerStrategy` must be used with
    `tf.distribute.experimental.coordinator.ClusterCoordinator`.
  """

  # pyformat: disable
  def __init__(self, cluster_resolver, variable_partitioner=None):
    """Initializes the TF2 parameter server strategy.

    This initializes the `tf.distribute.experimental.ParameterServerStrategy`
    object to be ready for use with
    `tf.distribute.experimental.coordinator.ClusterCoordinator`.

    Args:
      cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
        object.
      variable_partitioner:
        a `distribute.experimental.partitioners.Partitioner` that specifies
        how to partition variables. If `None`, variables will not be
        partitioned.

        * Predefined partitioners in `tf.distribute.experimental.partitioners`
        can be used for this argument. A commonly used partitioner is
        `MinSizePartitioner(min_shard_bytes=256 << 10, max_shards=num_ps)`,
        which allocates at least 256K per shard, and each ps gets at most one
        shard.

        * `variable_partitioner` will be called for each variable created
        under strategy `scope` to instruct how the variable should be
        partitioned. Variables that have only one partition along the
        partitioning axis (i.e., no need for partition) will be created as a
        normal `tf.Variable`.

        * Only the first / outermost axis partitioning is supported.

        * Div partition strategy is used to partition variables. Assuming we
        assign consecutive integer ids along the first axis of a variable,
        then ids are assigned to shards in a contiguous manner, while
        attempting to keep each shard size identical. If the number of ids
        does not evenly divide the number of shards, each of the first several
        shards is assigned one more id. For instance, a variable whose first
        dimension is 13 has 13 ids, and they are split across 5 shards as:
        `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

        * Variables created under `strategy.extended.colocate_vars_with` will
        not be partitioned.
    """
    # pyformat: enable
    self._cluster_resolver = cluster_resolver
    self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
                                                       variable_partitioner)
    self._verify_args_and_config(cluster_resolver)
    self._cluster_coordinator = None
    logging.info(
        "`tf.distribute.experimental.ParameterServerStrategy` is initialized "
        "with cluster_spec: %s", cluster_resolver.cluster_spec())

    # TODO(b/167894802): Make coordinator, worker, and ps names customizable.
    self._connect_to_cluster(coordinator_name="chief")
    super(ParameterServerStrategyV2, self).__init__(self._extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "ParameterServerStrategy")
    self._should_use_with_coordinator = True

  def _connect_to_cluster(self, coordinator_name):
    if coordinator_name in ["worker", "ps"]:
      raise ValueError("coordinator name should not be 'worker' or 'ps'.")
    cluster_spec = self._cluster_resolver.cluster_spec()
    self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
    self._num_ps = len(cluster_spec.as_dict().get("ps", ()))

    device_filters = server_lib.ClusterDeviceFilters()
    # For any worker, only the devices on ps and coordinator nodes are
    # visible.
    for i in range(self._num_workers):
      device_filters.set_device_filters(
          "worker", i, ["/job:ps", "/job:%s" % coordinator_name])
    # Similarly for any ps, only the devices on workers and coordinator are
    # visible.
    for i in range(self._num_ps):
      device_filters.set_device_filters(
          "ps", i, ["/job:worker", "/job:%s" % coordinator_name])

    # Allow at most one outstanding RPC for each worker at a certain time.
    # This is to simplify worker failure handling in the runtime.
    os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False"

    logging.info("%s is now connecting to cluster with cluster_spec: %r",
                 self.__class__.__name__, cluster_spec)
    remote.connect_to_cluster(
        cluster_spec,
        job_name=coordinator_name,
        protocol=self._cluster_resolver.rpc_layer,
        cluster_device_filters=device_filters)

    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "ps_strategy_num_workers").set(self._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "ps_strategy_num_ps").set(self._num_ps)

  def _verify_args_and_config(self, cluster_resolver):
    if not cluster_resolver.cluster_spec():
      raise ValueError("Cluster spec must be non-empty in "
                       "`tf.distribute.cluster_resolver.ClusterResolver`.")
    if self.extended._num_gpus_per_worker > 1:  # pylint: disable=protected-access
      raise NotImplementedError("Multi-gpu is not supported yet.")

    cluster_spec = cluster_resolver.cluster_spec()

    # The following checks if the task types are allowed (chief, ps, worker).
    multi_worker_util._validate_cluster_spec(  # pylint: disable=protected-access
        cluster_spec,
        cluster_resolver.task_type,
        cluster_resolver.task_id)

    if multi_worker_util.task_count(cluster_spec, "ps") < 1:
      raise ValueError("There must be at least one ps.")

    if multi_worker_util.task_count(cluster_spec, "worker") < 1:
      raise ValueError("There must be at least one worker.")


class ParameterServerStrategyV2Extended(
    parameter_server_strategy.ParameterServerStrategyExtended):
  """Extended class for ParameterServerStrategyV2.

  Please see `tf.distribute.StrategyExtended` doc for more information.
512 """ 513 514 def __init__(self, container_strategy, cluster_resolver, 515 variable_partitioner): 516 """Initialization of ParameterServerStrategyV2Extended.""" 517 super(ParameterServerStrategyV2Extended, self).__init__(container_strategy) 518 self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", [])) 519 self._variable_count = 0 520 self._variable_partitioner = variable_partitioner 521 522 # The following two attrs are to verify that `ParameterServerStrategy` 523 # methods are properly used with a `ClusterCoordinator`. 524 self._used_with_coordinator = False 525 self._being_scheduled = False 526 527 def _create_variable(self, next_creator, **kwargs): 528 """Implements StrategyExtendedV2._create_variable. 529 530 Creates a `Variable` or a `ShardedVariable`. A `ShardedVariable` will be 531 created if satisfying all the following criteria: 532 1. `self._variable_partitioner` results in more than one partition on the 533 first axis. 534 2. variable's rank is greater than 0. 535 3. variable is not colocated with another variable. 536 Otherwise a `Variable` will be created. 537 538 Args: 539 next_creator: See `variable_scope.variable_creator_scope`; the next 540 creator in the chain. 541 **kwargs: Passed through to the next creator. 542 543 Returns: 544 A `Variable` or `ShardedVariable`. 545 """ 546 547 if "colocate_with" in kwargs: # Never partition colocated_with variables. 548 colocate_with = kwargs["colocate_with"] 549 # Clear the variable scope to avoid possible conflicts between device 550 # scope and colocation scope. 551 with ops.device(None): 552 with ops.colocate_with(colocate_with): 553 var = next_creator(**kwargs) 554 logging.debug( 555 "Creating variable (name:%s, shape:%r) that colocates with %s", 556 var.name, var.shape, kwargs["colocate_with"].name) 557 return var 558 559 if self._variable_partitioner is None: 560 return self._create_variable_round_robin(next_creator, **kwargs) 561 562 name = kwargs.get("name", None) 563 initial_value = kwargs.get("initial_value", None) 564 if initial_value is None: 565 raise ValueError( 566 "It looks like you are using `ParameterServerStrategy` with a " 567 "`variable_partitioner`, and trying to create a variable without " 568 "specifying `initial_value`. This is not allowed. Please specify the " 569 "`initial_value`. This can also happen if you are trying to load a " 570 "saved_model within a `ParameterServerStrategy` scope. Loading a " 571 "saved_model with `variable_partitioner` is not supported.") 572 573 # Two cases where initial_value can be a callable: 574 # 1. initial_value is passed as a callable, e.g, an `initializer` class. 575 # 2. restoring from checkpoint, initial_value is a 576 # "CheckpointInitialValueCallable". 577 init_from_fn = callable(initial_value) 578 579 dtype = kwargs.get("dtype", None) 580 shape = kwargs.get("shape", None) 581 if init_from_fn and (shape is None or dtype is None): 582 init_from_fn = False 583 initial_value = initial_value() 584 if not init_from_fn: 585 # The initial_value is created on coordinator, it will need to be sent to 586 # ps for variable initialization, which can be inefficient and can 587 # potentially hit the 2GB limit on protobuf serialization. 588 initial_value = ops.convert_to_tensor(initial_value, dtype=dtype) 589 dtype = initial_value.dtype 590 shape = initial_value.shape 591 else: 592 shape = tensor_shape.as_shape(shape) 593 594 if shape.rank == 0: # Skip partitioning rank-0 variable. 
      return self._create_variable_round_robin(next_creator, **kwargs)

    num_partitions = self._variable_partitioner(shape=shape, dtype=dtype)
    if not num_partitions or num_partitions[0] == 0 or any(
        v != 1 for v in num_partitions[1:]):
      raise ValueError(
          "variable_partitioner must return a list/tuple whose elements are 1"
          " besides the first element (non-zero), got: %r" % num_partitions)

    if num_partitions[0] == 1:  # No partition.
      return self._create_variable_round_robin(next_creator, **kwargs)

    # Use "div" partition strategy to partition the variable.
    num_partitions = min(num_partitions[0], shape[0])
    base = shape[0] // num_partitions
    extra = shape[0] % num_partitions
    # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2]
    # offsets: [0, 3, 6, 8, 10]
    offsets = []
    for i in range(num_partitions):
      if i == 0:
        offsets.append(0)
      else:
        prev_shard_size = base + (1 if i - 1 < extra else 0)
        offsets.append(offsets[i - 1] + prev_shard_size)
    offsets.append(shape[0])

    def init_shard_fn(shard_index):
      if not init_from_fn:
        logging.log_if(
            logging.WARN, _INEFFICIENT_INIT_WARNING % name, shard_index == 0 and
            shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
        return initial_value[offsets[shard_index]:offsets[shard_index + 1]]
      partition_shape = (offsets[shard_index + 1] -
                         offsets[shard_index],) + shape[1:]
      partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:])
      arg_spec = tf_inspect.getfullargspec(initial_value)
      if ("shard_info" not in arg_spec.args and
          "shard_info" not in arg_spec.kwonlyargs):
        try:
          value = initial_value(
              partition_shape=partition_shape,
              partition_offset=partition_offset)
        except (TypeError, ValueError):
          # TypeError: Initializer doesn't accept kwargs
          # ValueError: Initializer doesn't accept partition kwargs
          # In both cases we go ahead creating the full value and then slice.
          value = initial_value()

        if value.shape == partition_shape:
          # Initializer supports partition: value is the partition value.
          return value
        else:
          # Initializer doesn't support partition: value is the full value
          # and needs to be sliced to get the partition value.
          logging.log_if(
              logging.WARN, _INEFFICIENT_INIT_WARNING % name,
              shard_index == 0 and
              shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
          return value[offsets[shard_index]:offsets[shard_index + 1]]
      else:
        # For compatibility with `CheckpointInitialValueCallable`.
        return initial_value(
            shard_info=trackable.ShardInfo(
                shape=tensor_shape.as_shape(partition_shape),
                offset=partition_offset))

    var_list = []
    for i in range(num_partitions):
      kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:]
      # The lambda is evaluated while the shard is created in this same loop
      # iteration, so capturing `i` by reference is not an issue here.
      kwargs["initial_value"] = lambda: init_shard_fn(i)
      if name is not None:
        kwargs["name"] = "{}/part_{}".format(name, i)
      var_list.append(self._create_variable_round_robin(next_creator, **kwargs))

    result = sharded_variable.ShardedVariable(var_list)
    return result

  def _create_variable_round_robin(self, next_creator, **kwargs):
    # Clear the colocation scope to avoid possible conflicts between device
    # scope and colocation scope.
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device("/job:ps/task:%d" %
                      (self._variable_count % self._num_ps)):
        var = next_creator(**kwargs)
        logging.debug(
            "Creating variable (name:%s, shape:%r) on /job:ps/task:%d",
            var.name, var.shape, (self._variable_count % self._num_ps))
        self._variable_count += 1
        return var

  def _assert_used_with_cluster_coordinator(self):
    if not self._used_with_coordinator:
      raise NotImplementedError(
          "`tf.distribute.experimental.ParameterServerStrategy` must be used "
          "with `tf.distribute.experimental.coordinator.ClusterCoordinator`.")

  def _assert_being_scheduled_by_cluster_coordinator(self):
    if not self._being_scheduled:
      raise NotImplementedError(
          "`tf.distribute.experimental.ParameterServerStrategy`'s `run` or "
          "`reduce` must be used within a function passed to `"
          "tf.distribute.experimental.coordinator.ClusterCoordinator.schedule"
          "`.")

  def _experimental_distribute_dataset(self, dataset, options):
    self._assert_used_with_cluster_coordinator()
    if not ops.get_default_graph().building_function:
      raise ValueError(
          "The `experimental_distribute_dataset` method must be called inside "
          "a `tf.function` passed to `create_per_worker_dataset` of "
          "`tf.distribute.experimental.coordinator.ClusterCoordinator`")
    return dataset

  def _distribute_datasets_from_function(self, dataset_fn, options):
    self._assert_used_with_cluster_coordinator()
    if not ops.get_default_graph().building_function:
      raise ValueError(
          "The `distribute_datasets_from_function` method must be called "
          "inside a `tf.function` passed to `create_per_worker_dataset` of "
          "`tf.distribute.experimental.coordinator.ClusterCoordinator`")
    return dataset_fn(distribute_lib.InputContext())

  def _call_for_each_replica(self, fn, args, kwargs):
    self._assert_being_scheduled_by_cluster_coordinator()
    with distribute_lib.ReplicaContext(
        self._container_strategy(),
        replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
      # TODO(rchao): Support multi-replica per worker or sync-group.
      return distribute_utils.regroup((fn(*args, **kwargs),))

  def _reduce(self, reduce_op, value):
    self._assert_being_scheduled_by_cluster_coordinator()
    # TODO(rchao): Provide implementation for multi-replica. Also look into why
    # the default implementation is not working.
    return value


# The warning that will be logged if the way we initialize sharded variables
# is memory-inefficient.
_INEFFICIENT_INIT_WARNING = (
    "Large variable %s is partitioned but not initialized in a "
    "memory-efficient way. On each shard, the full value is first being "
    "created and then sliced into smaller values. To reduce the memory "
    "footprint, explicitly specify `dtype` and `shape` when creating "
    "variables, and use `tf.initializers` to initialize the variable. "
    "Note that some initializers (e.g., orthogonal) don't support "
    "memory-efficient initialization and there is not much you can do here.")

_LARGE_VARIABLE_NUM_ELEMENTS = 1e9