# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python API for executing a tf.data.Dataset using a tf.data service."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import functools
import six

from tensorflow.core.protobuf import data_service_pb2
from tensorflow.python import tf2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import compression_ops
from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.data.experimental.service import _pywrap_utils
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops.options import AutoShardPolicy
from tensorflow.python.data.ops.options import ExternalStatePolicy
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_AUTO = "AUTO"
COMPRESSION_NONE = None
_PARALLEL_EPOCHS = "parallel_epochs"
_DISTRIBUTED_EPOCH = "distributed_epoch"

# TODO(b/176933539): Use the regular import.
nested_structure_coder = lazy_loader.LazyLoader(
    "nested_structure_coder", globals(),
    "tensorflow.python.saved_model.nested_structure_coder")


@tf_export("data.experimental.service.ShardingPolicy")
class ShardingPolicy(enum.IntEnum):
  """Specifies how to shard data among tf.data service workers.

  OFF: No sharding will be performed. Each worker produces the entire dataset
  without any sharding. With this mode, the best practice is to shuffle the
  dataset nondeterministically so that workers process the dataset in different
  orders. If workers are restarted or join the cluster mid-job, they will begin
  processing the dataset from the beginning.

  DYNAMIC: The input dataset is dynamically split among workers at runtime. Each
  worker gets the next split when it reads data from the dispatcher. Data is
  produced non-deterministically in this mode. Dynamic sharding works well with
  varying-sized tf.data service clusters, e.g., when you need to auto-scale your
  workers. Dynamic sharding provides at-most-once visitation guarantees. No
  examples will be repeated, but some may be missed if a tf.data service worker
  gets restarted while processing a file.

  The following are static sharding policies. The semantics are similar to
  `tf.data.experimental.AutoShardPolicy`. These policies require:
  * The tf.data service cluster is configured with a fixed list of workers
    in DispatcherConfig.
  * Each client only reads from the local tf.data service worker.

  If a worker is restarted while performing static sharding, the worker will
  begin processing its shard again from the beginning.

  FILE: Shards by input files (i.e. each worker will get a fixed set of files to
  process). When this option is selected, make sure that there are at least as
  many files as workers. If there are fewer input files than workers, a runtime
  error will be raised.

  DATA: Shards by elements produced by the dataset. Each worker will process the
  whole dataset and discard the portion that is not for itself. Note that for
  this mode to correctly partition the dataset elements, the dataset needs to
  produce elements in a deterministic order.

  FILE_OR_DATA: Attempts FILE-based sharding, falling back to DATA-based
  sharding on failure.

  HINT: Looks for the presence of `shard(SHARD_HINT, ...)` which is treated as a
  placeholder to replace with `shard(num_workers, worker_index)`.
  """

  # LINT.IfChange(tf_data_service_sharding_policy)
  OFF = 0
  DYNAMIC = 1
  FILE = 2
  DATA = 3
  FILE_OR_DATA = 4
  HINT = 5
  # LINT.ThenChange()

  def _to_proto(self):
    """Converts the policy to ProcessingModeDef proto enum."""

    if self == ShardingPolicy.OFF:
      return data_service_pb2.ProcessingModeDef.OFF
    if self == ShardingPolicy.DYNAMIC:
      return data_service_pb2.ProcessingModeDef.DYNAMIC
    if self == ShardingPolicy.FILE:
      return data_service_pb2.ProcessingModeDef.FILE
    if self == ShardingPolicy.DATA:
      return data_service_pb2.ProcessingModeDef.DATA
    if self == ShardingPolicy.FILE_OR_DATA:
      return data_service_pb2.ProcessingModeDef.FILE_OR_DATA
    if self == ShardingPolicy.HINT:
      return data_service_pb2.ProcessingModeDef.HINT
    raise ValueError(
        f"Unable to convert sharding policy {self!r} to proto. Please verify "
        "the policy mapping.")
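

# NOTE: Illustrative usage (added example, not part of the original module): a
# `ShardingPolicy` value can be passed directly as the `processing_mode` of
# `tf.data.experimental.service.distribute`. The dispatcher address below is a
# placeholder.
#
#   dataset = dataset.apply(
#       tf.data.experimental.service.distribute(
#           processing_mode=(
#               tf.data.experimental.service.ShardingPolicy.DYNAMIC),
#           service="grpc://localhost:5000"))
#
# This is equivalent to passing the legacy string "distributed_epoch".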


def _get_validated_sharding_policy(processing_mode):
  """Validates `processing_mode` and converts it to ShardingPolicy."""

  if isinstance(processing_mode, ShardingPolicy):
    return processing_mode
  if compat.forward_compatible(2021, 8, 24):
    if processing_mode == _PARALLEL_EPOCHS:
      return ShardingPolicy.OFF
    if processing_mode == _DISTRIBUTED_EPOCH:
      return ShardingPolicy.DYNAMIC
  elif processing_mode in [_PARALLEL_EPOCHS, _DISTRIBUTED_EPOCH]:
    return processing_mode

  raise ValueError(
      "tf.data service processing mode should be a ShardingPolicy, "
      "`\"parallel_epochs\"`, or `\"distributed_epoch\"`. Got "
      f"{processing_mode!r}.")


def _serialize(processing_mode):
  """Serializes `processing_mode`."""

  processing_mode = _get_validated_sharding_policy(processing_mode)
  if isinstance(processing_mode, ShardingPolicy):
    # pylint: disable=protected-access
    processing_mode_def = data_service_pb2.ProcessingModeDef(
        sharding_policy=_get_validated_sharding_policy(
            processing_mode)._to_proto())
    return processing_mode_def.SerializeToString()
  if processing_mode in [_PARALLEL_EPOCHS, _DISTRIBUTED_EPOCH]:
    return processing_mode

  raise ValueError(
      "tf.data service processing mode should be a ShardingPolicy, "
      "`\"parallel_epochs\"`, or `\"distributed_epoch\"`. Got "
      f"{processing_mode!r}.")


def _validate_job_name(job_name):
  if job_name is None:
    return
  if not isinstance(job_name, six.string_types):
    raise ValueError("job_name must be a string, but job_name was of type "
                     "{0}. job_name={1}".format(type(job_name), job_name))
  if not job_name:
    raise ValueError("job_name must not be empty")


class _DataServiceDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` that reads elements from the tf.data service."""

  def __init__(self,
               dataset_id,
               processing_mode,
               address,
               element_spec,
               protocol,
               data_transfer_protocol,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               task_refresh_interval_hint_ms=None,
               target_workers="AUTO"):
    """Constructs a _DataServiceDatasetV2.

    Args:
      dataset_id: The dataset id for the dataset to read from.
      processing_mode: A `tf.data.experimental.service.ShardingPolicy`
        specifying how to shard the dataset among tf.data workers. See
        `tf.data.experimental.service.ShardingPolicy` for details. For backwards
        compatibility, `processing_mode` may also be set to the strings
        `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
        equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
      address: The tf.data service address, e.g. "localhost:5000".
      element_spec: The dataset element spec for the dataset to read from.
      protocol: The protocol to use for communicating with the tf.data service,
        e.g. "grpc".
      data_transfer_protocol: (Optional.) The protocol to use for transferring
        data with the tf.data service. By default, data is transferred using
        gRPC.
      job_name: (Optional.) The name of the job. If provided, it must be a
        non-empty string or Tensor. This argument makes it possible
        for multiple datasets to share the same job. The default behavior is
        that the dataset creates anonymous, exclusively owned jobs.
      consumer_index: (Optional.) The index of the consumer in the range from
        `0` to `num_consumers`. Must be specified alongside `num_consumers`.
        When specified, consumers will read from the job in a strict round-robin
        order, instead of the default first-come-first-served order.
      num_consumers: (Optional.) The number of consumers which will consume from
        the job. Must be specified alongside `consumer_index`. When specified,
        consumers will read from the job in a strict round-robin order, instead
        of the default first-come-first-served order. When `num_consumers` is
        specified, the dataset must have infinite cardinality to prevent a
        producer from running out of data early and causing consumers to go out
        of sync.
      max_outstanding_requests: (Optional.) A limit on how many elements may be
        requested at the same time. You can use this option to control the
        amount of memory used, since `distribute` won't use more than
        `element_size` * `max_outstanding_requests` of memory.
      task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
        the dispatcher for task changes.
      target_workers: (Optional.) Which workers to read from. If `"AUTO"`,
        tf.data runtime decides which workers to read from. If `"ANY"`, reads
        from any tf.data service workers. If `"LOCAL"`, only reads from local
        in-process tf.data service workers. `"AUTO"` works well for most cases,
        while users can specify other targets. For example, `"LOCAL"` helps
        avoid RPCs and data copy if every TF worker colocates with a tf.data
        service worker. Consumers of a shared job must use the same
        `target_workers`. Defaults to `"AUTO"`.
    """
    processing_mode = _serialize(
        _get_validated_sharding_policy(processing_mode))
    if (consumer_index is None) != (num_consumers is None):
      raise ValueError(
          "Must either set both consumer_index and num_consumers, or neither. "
          f"consumer_index: {consumer_index}, num_consumers: {num_consumers}")
    if num_consumers is not None and job_name is None:
      raise ValueError("job_name must be set when setting num_consumers")

    if job_name is None:
      job_name = ""
    if max_outstanding_requests is None:
      max_outstanding_requests = dataset_ops.AUTOTUNE
    if task_refresh_interval_hint_ms is None:
      task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE

    self._dataset_id = ops.convert_to_tensor(
        dataset_id, dtype=dtypes.int64, name="dataset_id")
    self._processing_mode = ops.convert_to_tensor(
        processing_mode, dtype=dtypes.string, name="processing_mode")
    self._address = ops.convert_to_tensor(
        address, dtype=dtypes.string, name="address")
    self._protocol = ops.convert_to_tensor(
        protocol, dtype=dtypes.string, name="protocol")
    self._job_name = ops.convert_to_tensor(
        job_name, dtype=dtypes.string, name="job_name")
    self._consumer_index = ops.convert_to_tensor(
        -1 if consumer_index is None else consumer_index,
        dtype=dtypes.int64,
        name="consumer_index")
    self._num_consumers = ops.convert_to_tensor(
        -1 if num_consumers is None else num_consumers,
        dtype=dtypes.int64,
        name="num_consumers")
    self._max_outstanding_requests = ops.convert_to_tensor(
        max_outstanding_requests,
        dtype=dtypes.int64,
        name="max_outstanding_requests")
    self._element_spec = element_spec
    self._target_workers = target_workers

    compat_kwargs = {}
    if data_transfer_protocol is not None:
      compat_kwargs["data_transfer_protocol"] = data_transfer_protocol
    if compat.forward_compatible(2021, 7, 12) or target_workers != "AUTO":
      compat_kwargs["target_workers"] = target_workers

    variant_tensor = gen_experimental_dataset_ops.data_service_dataset_v2(
        dataset_id=self._dataset_id,
        processing_mode=self._processing_mode,
        address=self._address,
        protocol=self._protocol,
        job_name=self._job_name,
        consumer_index=self._consumer_index,
        num_consumers=self._num_consumers,
        max_outstanding_requests=self._max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        iteration_counter=gen_experimental_dataset_ops.dummy_iteration_counter(
        ),
        **compat_kwargs,
        **self._flat_structure)
    super(_DataServiceDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` that executes its input through the tf.data service."""

  @functools.wraps(_DataServiceDatasetV2.__init__)
  def __init__(self, dataset_id, processing_mode, address, element_spec,
               protocol, data_transfer_protocol, job_name, consumer_index,
               num_consumers, max_outstanding_requests,
               task_refresh_interval_hint_ms, target_workers):

    self._wrapped = _DataServiceDatasetV2(
        dataset_id=dataset_id,
        processing_mode=processing_mode,
        address=address,
        element_spec=element_spec,
        protocol=protocol,
        data_transfer_protocol=data_transfer_protocol,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        target_workers=target_workers)
    super(_DataServiceDatasetV1, self).__init__(self._wrapped)


if tf2.enabled():
  _DataServiceDataset = _DataServiceDatasetV2
else:
  _DataServiceDataset = _DataServiceDatasetV1


def _parse_service(service):
  """Converts a tf.data service string into a (protocol, address) tuple.

  Args:
    service: A string in the format "protocol://address" or just "address". If
      the string is only an address, the default protocol will be used.

  Returns:
    The (protocol, address) tuple.
  """
  if not isinstance(service, six.string_types):
    raise ValueError(
        "service must be a string, but service was of type {0}. service={1}"
        .format(type(service), service))
  if not service:
    raise ValueError("service must not be empty")
  parts = service.split("://")
  if len(parts) == 2:
    protocol, address = parts
  elif len(parts) == 1:
    address = parts[0]
    protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
  else:
    raise ValueError("malformed service string has multiple '://': %s" %
                     service)
  # TODO(aaudibert): Consider validating reachability of address here.
  return (protocol, address)
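

# Example of `_parse_service` behavior (illustrative, added comment; assumes the
# default protocol returned by `_pywrap_utils.TF_DATA_DefaultProtocol()` is
# "grpc", which may differ in some builds):
#
#   _parse_service("grpc://localhost:5000")  # -> ("grpc", "localhost:5000")
#   _parse_service("localhost:5000")         # -> ("grpc", "localhost:5000")
#   _parse_service("a://b://c")              # raises ValueError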


def _distribute(processing_mode,
                service,
                job_name=None,
                consumer_index=None,
                num_consumers=None,
                max_outstanding_requests=None,
                task_refresh_interval_hint_ms=None,
                data_transfer_protocol=None,
                compression="AUTO",
                target_workers="AUTO"):
  """A transformation that moves dataset processing to the tf.data service.

  This transformation is similar to `distribute`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from `0`
      to `num_consumers`. Must be specified alongside `num_consumers`. When
      specified, consumers will read from the job in a strict round-robin order,
      instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead of
      the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
      dispatcher for task changes.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using gRPC.
    compression: How to compress the dataset's elements before transferring them
      over the network. "AUTO" leaves the decision of how to compress up to the
      tf.data service runtime. `None` indicates not to compress.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data
      runtime decides which workers to read from. If `"ANY"`, reads from any
      tf.data service workers. If `"LOCAL"`, only reads from local in-process
      tf.data service workers. `"AUTO"` works well for most cases, while users
      can specify other targets. For example, `"LOCAL"` helps avoid RPCs and
      data copy if every TF worker colocates with a tf.data service worker.
      Consumers of a shared job must use the same `target_workers`. Defaults
      to `"AUTO"`.

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  processing_mode = _get_validated_sharding_policy(processing_mode)
  valid_compressions = [COMPRESSION_AUTO, COMPRESSION_NONE]
  if compression not in valid_compressions:
    raise ValueError(
        "Invalid compression argument: {}. Must be one of {}".format(
            compression, valid_compressions))
  if compression == COMPRESSION_AUTO and data_transfer_protocol is not None:
    compression = COMPRESSION_NONE

  def _apply_fn(dataset):  # pylint: disable=missing-docstring
    dataset_id = _register_dataset(service, dataset, compression=compression)
    return _from_dataset_id(
        processing_mode,
        service,
        dataset_id,
        dataset.element_spec,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        data_transfer_protocol=data_transfer_protocol,
        compression=compression,
        target_workers=target_workers)

  return _apply_fn


@tf_export("data.experimental.service.distribute")
def distribute(processing_mode,
               service,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               data_transfer_protocol=None,
               compression="AUTO",
               target_workers="AUTO"):
  """A transformation that moves dataset processing to the tf.data service.

  When you iterate over a dataset containing the `distribute` transformation,
  the tf.data service creates a "job" which produces data for the dataset
  iteration.

  The tf.data service uses a cluster of workers to prepare data for training
  your model.
  The `processing_mode` argument to `tf.data.experimental.service.distribute`
  describes how to leverage multiple workers to process the input dataset.
  Currently, there are two processing modes to choose from: "distributed_epoch"
  and "parallel_epochs".

  "distributed_epoch" means that the dataset will be split across all tf.data
  service workers.
  The dispatcher produces "splits" for the dataset and sends them to workers for
  further processing. For example, if a dataset begins with a list of filenames,
  the dispatcher will iterate through the filenames and send the filenames to
  tf.data workers, which will perform the rest of the dataset transformations on
  those files. "distributed_epoch" is useful when your model needs to see each
  element of the dataset exactly once, or if it needs to see the data in a
  generally-sequential order. "distributed_epoch" only works for datasets with
  splittable sources, such as `Dataset.from_tensor_slices`,
  `Dataset.list_files`, or `Dataset.range`.

  "parallel_epochs" means that the entire input dataset will be processed
  independently by each of the tf.data service workers.
  For this reason, it is important to shuffle data (e.g. filenames)
  non-deterministically, so that each worker will process the elements of the
  dataset in a different order. "parallel_epochs" can be used to distribute
  datasets that aren't splittable.

  With two workers, "parallel_epochs" will produce every element of the dataset
  twice:

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> # Start two workers
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9]

  "distributed_epoch", on the other hand, will still produce each element once:

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="distributed_epoch", service=dispatcher.target))
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When using `apply(tf.data.experimental.service.distribute(...))`, the dataset
  before the `apply` transformation executes within the tf.data service, while
  the operations after `apply` happen within the local process.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(5)
  >>> dataset = dataset.map(lambda x: x*x)
  >>> dataset = dataset.apply(
  ...    tf.data.experimental.service.distribute("parallel_epochs",
  ...                                            dispatcher.target))
  >>> dataset = dataset.map(lambda x: x+1)
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [1, 1, 2, 2, 5, 5, 10, 10, 17, 17]

  In the above example, the dataset operations (before applying the `distribute`
  function on the elements) will be executed on the tf.data workers,
  and the elements are provided over RPC. The remaining transformations
  (after the call to `distribute`) will be executed locally. The dispatcher
  and the workers will bind to unused free ports (which are chosen at random),
  in order to communicate with each other. However, to bind them to specific
  ports, the `port` parameter can be passed.

  The `job_name` argument allows jobs to be shared across multiple
  datasets. Instead of each dataset creating its own job, all
  datasets with the same `job_name` will consume from the same job. A new job
  will be created for each iteration of the dataset (with each repetition of
  `Dataset.repeat` counting as a new iteration). Suppose the `DispatchServer`
  is serving on `localhost:5000` and two training workers (in either a single
  client or multi-client setup) iterate over the below dataset, and there is a
  single tf.data worker:

  ```
  range5_dataset = tf.data.Dataset.range(5)
  dataset = range5_dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "localhost:5000", job_name="my_job_name"))
  for iteration in range(3):
    print(list(dataset))
  ```

  The elements of each job will be split between the two processes, with
  elements being consumed by the processes on a first-come first-served basis.
  One possible result is that process 1 prints

  ```
  [0, 2, 4]
  [0, 1, 3]
  [1]
  ```

  and process 2 prints

  ```
  [1, 3]
  [2, 4]
  [0, 2, 3, 4]
  ```

  Job names must not be re-used across different training jobs within the
  lifetime of the tf.data service. In general, the tf.data service is expected
  to live for the duration of a single training job.
  To use the tf.data service with multiple training jobs, make sure to use
  different job names to avoid conflicts. For example, suppose a training job
  calls `distribute` with `job_name="job"` and reads until end of input. If
  another independent job connects to the same tf.data service and tries to read
  from `job_name="job"`, it will immediately receive end of input, without
  getting any data.

  **Round Robin data consumption**

  By default, when multiple consumers read from the same job, they receive data
  on a first-come first-served basis. In some use cases, it works better to use
  a strict round-robin order. For example, the tf.data service can be used to
  coordinate example sizes across a cluster during synchronous training, so that
  during each step all replicas train on similar-sized elements. To achieve
  this, define a dataset which generates rounds of `num_consumers` consecutive
  similar-sized batches, then enable round-robin reads by setting
  `consumer_index` and `num_consumers`.

  Consumers read data by cycling through all workers, reading one element from
  each. First, each consumer will read an element from the first worker, then
  each consumer will read an element from the second worker, and so on.

  NOTE: To keep consumers in sync, round robin data consumption requires that
  the dataset have infinite cardinality. You can get this by adding `.repeat()`
  at the end of the dataset definition.

  **Keras and Distribution Strategies**

  The dataset produced by the `distribute` transformation can be passed to
  Keras' `Model.fit` or Distribution Strategy's
  `tf.distribute.Strategy.experimental_distribute_dataset` like any other
  `tf.data.Dataset`. We recommend setting a `job_name` on the call to
  `distribute` so that if there are multiple workers, they read data from the
  same job. Note that the autosharding normally performed by
  `experimental_distribute_dataset` will be disabled when setting a `job_name`,
  since sharing the job already results in splitting data across the workers.
  When using a shared job, data will be dynamically balanced across workers, so
  that they reach end of input about the same time. This results in better
  worker utilization than with autosharding, where each worker processes an
  independent set of files, and some workers may run out of data earlier than
  others.
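
  A sketch of this pattern (an added, illustrative example; the strategy and
  dispatcher address below are placeholders):

  ```
  strategy = tf.distribute.MirroredStrategy()
  dataset = tf.data.Dataset.range(10)
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "grpc://dispatcher:5000", job_name="shared_job"))
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
  ```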

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from `0`
      to `num_consumers`. Must be specified alongside `num_consumers`. When
      specified, consumers will read from the job in a strict round-robin order,
      instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead of
      the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using gRPC.
    compression: How to compress the dataset's elements before transferring them
      over the network. "AUTO" leaves the decision of how to compress up to the
      tf.data service runtime. `None` indicates not to compress.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data
      runtime decides which workers to read from. If `"ANY"`, reads from any
      tf.data service workers. If `"LOCAL"`, only reads from local in-process
      tf.data service workers. `"AUTO"` works well for most cases, while users
      can specify other targets. For example, `"LOCAL"` helps avoid RPCs and
      data copy if every TF worker colocates with a tf.data service worker.
      Consumers of a shared job must use the same `target_workers`. Defaults
      to `"AUTO"`.

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  _validate_job_name(job_name)
  return _distribute(
      processing_mode=processing_mode,
      service=service,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      data_transfer_protocol=data_transfer_protocol,
      compression=compression,
      target_workers=target_workers)
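

# Illustrative sketch (added example, not part of the original module): enabling
# the strict round-robin reads described in the `distribute` docstring. The
# dispatcher address is a placeholder, and each consumer passes its own
# `consumer_index`.
#
#   dataset = dataset.repeat()  # Round-robin reads need infinite cardinality.
#   dataset = dataset.apply(
#       tf.data.experimental.service.distribute(
#           processing_mode="parallel_epochs",
#           service="grpc://dispatcher:5000",
#           job_name="shared_job",
#           consumer_index=0,  # This consumer's index in [0, num_consumers).
#           num_consumers=2))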


def _register_dataset(service, dataset, compression):
  """Registers a dataset with the tf.data service.

  This transformation is similar to `register_dataset`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset: A `tf.data.Dataset` to register with the tf.data service.
    compression: How to compress the dataset's elements before transferring them
      over the network. "AUTO" leaves the decision of how to compress up to the
      tf.data service runtime. `None` indicates not to compress.

  Returns:
    A scalar int64 tensor of the registered dataset's id.
  """
  valid_compressions = [COMPRESSION_AUTO, COMPRESSION_NONE]
  if compression not in valid_compressions:
    raise ValueError(
        "Invalid compression argument: {}. Must be one of {}".format(
            compression, valid_compressions))
  if isinstance(service, tuple):
    protocol, address = service
  else:
    protocol, address = _parse_service(service)
  external_state_policy = dataset.options().experimental_external_state_policy
  if external_state_policy is None:
    external_state_policy = ExternalStatePolicy.WARN

  encoded_spec = ""
  if context.executing_eagerly():
    coder = nested_structure_coder.StructureCoder()
    encoded_spec = coder.encode_structure(
        dataset.element_spec).SerializeToString()

  if compression == COMPRESSION_AUTO:
    dataset = dataset.map(
        lambda *x: compression_ops.compress(x),
        num_parallel_calls=dataset_ops.AUTOTUNE)
    dataset = dataset.prefetch(dataset_ops.AUTOTUNE)
  dataset = dataset._apply_debug_options()  # pylint: disable=protected-access

  dataset_id = gen_experimental_dataset_ops.register_dataset(
      dataset._variant_tensor,  # pylint: disable=protected-access
      address=address,
      protocol=protocol,
      external_state_policy=external_state_policy.value,
      element_spec=encoded_spec)

  return dataset_id


@tf_export("data.experimental.service.register_dataset")
def register_dataset(service, dataset):
  """Registers a dataset with the tf.data service.

  `register_dataset` registers a dataset with the tf.data service so that
  datasets can be created later with
  `tf.data.experimental.service.from_dataset_id`. This is useful when the
  dataset is registered by one process, then used in another process. When the
  same process is both registering and reading from the dataset, it is simpler
  to use `tf.data.experimental.service.distribute` instead.

  If the dataset is already registered with the tf.data service,
  `register_dataset` returns the already-registered dataset's id.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  Args:
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset: A `tf.data.Dataset` to register with the tf.data service.

  Returns:
    A scalar int64 tensor of the registered dataset's id.
  """
  return _register_dataset(service, dataset, compression="AUTO")


def _from_dataset_id(processing_mode,
                     service,
                     dataset_id,
                     element_spec,
                     job_name=None,
                     consumer_index=None,
                     num_consumers=None,
                     max_outstanding_requests=None,
                     task_refresh_interval_hint_ms=None,
                     data_transfer_protocol=None,
                     compression="AUTO",
                     target_workers="AUTO"):
  """Creates a dataset which reads data from the tf.data service.

  This transformation is similar to `from_dataset_id`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type of
      elements produced by the dataset. This argument is only required inside a
      tf.function. Use `tf.data.Dataset.element_spec` to get the element spec
      for a given dataset.
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string or tensor. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from `0`
      to `num_consumers`. Must be specified alongside `num_consumers`. When
      specified, consumers will read from the job in a strict round-robin order,
      instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead of
      the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
      dispatcher for task changes.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using gRPC.
    compression: An indication of how the dataset's elements were compressed, so
      that `from_dataset_id` can uncompress them if necessary.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data
      runtime decides which workers to read from. If `"ANY"`, reads from any
      tf.data service workers. If `"LOCAL"`, only reads from local in-process
      tf.data service workers. `"AUTO"` works well for most cases, while users
      can specify other targets. For example, `"LOCAL"` helps avoid RPCs and
      data copy if every TF worker colocates with a tf.data service worker.
      Consumers of a shared job must use the same `target_workers`. Defaults
      to `"AUTO"`.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """
  processing_mode = _get_validated_sharding_policy(processing_mode)
  valid_compressions = [COMPRESSION_AUTO, COMPRESSION_NONE]
  if isinstance(service, tuple):
    protocol, address = service
  else:
    protocol, address = _parse_service(service)

  if compression not in valid_compressions:
    raise ValueError(
        "Invalid compression argument: {}. Must be one of {}".format(
            compression, valid_compressions))
  if job_name is not None:
    if not isinstance(job_name, six.string_types) and not isinstance(
        job_name, ops.Tensor):
      raise ValueError(
          "job_name must be a string or Tensor, but job_name was of type "
          "{0}. job_name={1}".format(type(job_name), job_name))

  if element_spec is None:
    if not context.executing_eagerly():
      raise ValueError("In graph mode element_spec must be provided manually.")

    dataset_id_val = tensor_util.constant_value(dataset_id)
    try:
      encoded_spec = _pywrap_server_lib.TF_DATA_GetElementSpec(
          dataset_id_val, address, protocol)

    except NotImplementedError as err:
      raise ValueError("The tf.data service is running an earlier version of "
                       "TensorFlow that requires specifying `element_spec` as "
                       "an argument to `from_dataset_id`. Please either supply "
                       "an element spec or update the tf.data service to the "
                       "latest version.") from err

    except RuntimeError as err:
      raise ValueError("Failed to fetch element spec for dataset id " +
                       str(dataset_id_val) + " from tf.data service. If the "
                       "dataset was registered in graph mode or inside a "
                       "tf.function, the `element_spec` must be specified as "
                       "an argument to `from_dataset_id`.") from err

    struct_pb = nested_structure_coder.struct_pb2.StructuredValue()
    struct_pb.ParseFromString(encoded_spec)
    coder = nested_structure_coder.StructureCoder()
    element_spec = coder.decode_proto(struct_pb)

  # If we compress, the data service side dataset will produce scalar variants.
  data_service_element_spec = (
      tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
      if compression == COMPRESSION_AUTO else element_spec)

  dataset = _DataServiceDataset(
      dataset_id=dataset_id,
      processing_mode=processing_mode,
      address=address,
      element_spec=data_service_element_spec,
      protocol=protocol,
      data_transfer_protocol=data_transfer_protocol,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
      target_workers=target_workers)
  if compression == COMPRESSION_AUTO:
    dataset = dataset.map(
        lambda x: compression_ops.uncompress(x, output_spec=element_spec),
        num_parallel_calls=dataset_ops.AUTOTUNE)

  # Disable autosharding for shared jobs.
  if job_name is not None:
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
    dataset = dataset.with_options(options)
  return dataset


@tf_export("data.experimental.service.from_dataset_id")
def from_dataset_id(processing_mode,
                    service,
                    dataset_id,
                    element_spec=None,
                    job_name=None,
                    consumer_index=None,
                    num_consumers=None,
                    max_outstanding_requests=None,
                    data_transfer_protocol=None,
                    target_workers="AUTO"):
  """Creates a dataset which reads data from the tf.data service.

  This is useful when the dataset is registered by one process, then used in
  another process. When the same process is both registering and reading from
  the dataset, it is simpler to use `tf.data.experimental.service.distribute`
  instead.

  Before using `from_dataset_id`, the dataset must have been registered with the
  tf.data service using `tf.data.experimental.service.register_dataset`.
  `register_dataset` returns a dataset id for the registered dataset. That is
  the `dataset_id` which should be passed to `from_dataset_id`.

  The `element_spec` argument indicates the `tf.TypeSpec`s for the elements
  produced by the dataset. Inside a `tf.function` (and in graph mode generally),
  `element_spec` must be explicitly specified, and it must match the dataset
  registered under `dataset_id`. When executing eagerly, `element_spec` may be
  left as `None`, in which case it is fetched from the tf.data service (older
  versions of the service may still require it to be passed explicitly).

  `tf.data.experimental.service.distribute` is a convenience method which
  combines `register_dataset` and `from_dataset_id` into a dataset
  transformation. See the documentation for
  `tf.data.experimental.service.distribute` for more detail about how
  `from_dataset_id` works.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type of
      elements produced by the dataset. This argument is only required inside a
      tf.function. Use `tf.data.Dataset.element_spec` to get the element spec
      for a given dataset.
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from `0`
      to `num_consumers`. Must be specified alongside `num_consumers`. When
      specified, consumers will read from the job in a strict round-robin order,
      instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead of
      the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using gRPC.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data
      runtime decides which workers to read from. If `"ANY"`, reads from any
      tf.data service workers. If `"LOCAL"`, only reads from local in-process
      tf.data service workers. `"AUTO"` works well for most cases, while users
      can specify other targets. For example, `"LOCAL"` helps avoid RPCs and
      data copy if every TF worker colocates with a tf.data service worker.
      Consumers of a shared job must use the same `target_workers`. Defaults
      to `"AUTO"`.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """
  _validate_job_name(job_name)
  if job_name is not None:
    job_name = string_ops.string_join(
        ["dataset_id=", string_ops.as_string(dataset_id), job_name], "/")

  return _from_dataset_id(
      processing_mode=processing_mode,
      service=service,
      dataset_id=dataset_id,
      element_spec=element_spec,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      data_transfer_protocol=data_transfer_protocol,
      target_workers=target_workers)
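

# Illustrative sketch (added example, not part of the original module): using
# `from_dataset_id` inside a `tf.function`, where `element_spec` must be passed
# explicitly because it cannot be fetched from the dispatcher there. The
# dispatcher address and the element spec are placeholders.
#
#   spec = tf.TensorSpec(shape=(), dtype=tf.int64)
#
#   @tf.function
#   def make_dataset(dataset_id):
#     return tf.data.experimental.service.from_dataset_id(
#         processing_mode="parallel_epochs",
#         service="grpc://dispatcher:5000",
#         dataset_id=dataset_id,
#         element_spec=spec)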