# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python API for executing a tf.data.Dataset using a tf.data service."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import functools
import six

from tensorflow.core.protobuf import data_service_pb2
from tensorflow.python import tf2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import compression_ops
from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.data.experimental.service import _pywrap_utils
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops.options import AutoShardPolicy
from tensorflow.python.data.ops.options import ExternalStatePolicy
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_AUTO = "AUTO"
COMPRESSION_NONE = None
_PARALLEL_EPOCHS = "parallel_epochs"
_DISTRIBUTED_EPOCH = "distributed_epoch"

# TODO(b/176933539): Use the regular import.
nested_structure_coder = lazy_loader.LazyLoader(
    "nested_structure_coder", globals(),
    "tensorflow.python.saved_model.nested_structure_coder")


@tf_export("data.experimental.service.ShardingPolicy")
class ShardingPolicy(enum.IntEnum):
  """Specifies how to shard data among tf.data service workers.

  OFF: No sharding will be performed. Each worker produces the entire dataset
  without any sharding. With this mode, the best practice is to shuffle the
  dataset nondeterministically so that workers process the dataset in different
  orders. If workers are restarted or join the cluster mid-job, they will begin
  processing the dataset from the beginning.

  DYNAMIC: The input dataset is dynamically split among workers at runtime. Each
  worker gets the next split when it reads data from the dispatcher. Data is
  produced non-deterministically in this mode. Dynamic sharding works well with
  varying-sized tf.data service clusters, e.g., when you need to auto-scale your
  workers. Dynamic sharding provides at-most-once visitation guarantees. No
  examples will be repeated, but some may be missed if a tf.data service worker
  gets restarted while processing a file.

  The following are static sharding policies. The semantics are similar to
  `tf.data.experimental.AutoShardPolicy`. These policies require:
  * The tf.data service cluster is configured with a fixed list of workers
    in DispatcherConfig.
  * Each client only reads from the local tf.data service worker.

  If a worker is restarted while performing static sharding, the worker will
  begin processing its shard again from the beginning.

  FILE: Shards by input files (i.e. each worker will get a fixed set of files to
  process). When this option is selected, make sure that there are at least as
  many files as workers. If there are fewer input files than workers, a runtime
  error will be raised.

  DATA: Shards by elements produced by the dataset. Each worker will process the
  whole dataset and discard the portion that is not for itself. Note that for
  this mode to correctly partition the dataset elements, the dataset needs to
  produce elements in a deterministic order.

  FILE_OR_DATA: Attempts FILE-based sharding, falling back to DATA-based
  sharding on failure.

  HINT: Looks for the presence of `shard(SHARD_HINT, ...)`, which is treated as
  a placeholder to replace with `shard(num_workers, worker_index)`.
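
  For example, HINT-based sharding expects the input pipeline to contain a
  shard placeholder, roughly as follows (a sketch, assuming the placeholder
  constant is exported as `tf.data.experimental.SHARD_HINT`):

  ```
  dataset = tf.data.Dataset.list_files("/path/to/data-*")
  # With ShardingPolicy.HINT, the tf.data service rewrites this placeholder
  # into shard(num_workers, worker_index) for each worker.
  dataset = dataset.shard(tf.data.experimental.SHARD_HINT,
                          tf.data.experimental.SHARD_HINT)
  ```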
  """

  # LINT.IfChange(tf_data_service_sharding_policy)
  OFF = 0
  DYNAMIC = 1
  FILE = 2
  DATA = 3
  FILE_OR_DATA = 4
  HINT = 5
  # LINT.ThenChange()

  def _to_proto(self):
    """Converts the policy to ProcessingModeDef proto enum."""

    if self == ShardingPolicy.OFF:
      return data_service_pb2.ProcessingModeDef.OFF
    if self == ShardingPolicy.DYNAMIC:
      return data_service_pb2.ProcessingModeDef.DYNAMIC
    if self == ShardingPolicy.FILE:
      return data_service_pb2.ProcessingModeDef.FILE
    if self == ShardingPolicy.DATA:
      return data_service_pb2.ProcessingModeDef.DATA
    if self == ShardingPolicy.FILE_OR_DATA:
      return data_service_pb2.ProcessingModeDef.FILE_OR_DATA
    if self == ShardingPolicy.HINT:
      return data_service_pb2.ProcessingModeDef.HINT
    raise ValueError(
        f"Unable to convert sharding policy {self!r} to proto. Please verify "
        "the policy mapping.")


def _get_validated_sharding_policy(processing_mode):
  """Validates `processing_mode` and converts it to ShardingPolicy."""

  if isinstance(processing_mode, ShardingPolicy):
    return processing_mode
  if compat.forward_compatible(2021, 8, 24):
    if processing_mode == _PARALLEL_EPOCHS:
      return ShardingPolicy.OFF
    if processing_mode == _DISTRIBUTED_EPOCH:
      return ShardingPolicy.DYNAMIC
  elif processing_mode in [_PARALLEL_EPOCHS, _DISTRIBUTED_EPOCH]:
    return processing_mode

  raise ValueError(
      "tf.data service processing mode should be a ShardingPolicy, "
      "`\"parallel_epochs\"`, or `\"distributed_epoch\"`. Got "
      f"{processing_mode!r}.")


def _serialize(processing_mode):
  """Serializes `processing_mode`."""

  processing_mode = _get_validated_sharding_policy(processing_mode)
  if isinstance(processing_mode, ShardingPolicy):
    # pylint: disable=protected-access
    processing_mode_def = data_service_pb2.ProcessingModeDef(
        sharding_policy=_get_validated_sharding_policy(
            processing_mode)._to_proto())
    return processing_mode_def.SerializeToString()
  if processing_mode in [_PARALLEL_EPOCHS, _DISTRIBUTED_EPOCH]:
    return processing_mode

  raise ValueError(
      "tf.data service processing mode should be a ShardingPolicy, "
      "`\"parallel_epochs\"`, or `\"distributed_epoch\"`. Got "
      f"{processing_mode!r}.")


def _validate_job_name(job_name):
  if job_name is None:
    return
  if not isinstance(job_name, six.string_types):
    raise ValueError("job_name must be a string, but job_name was of type "
                     "{0}. job_name={1}".format(type(job_name), job_name))
  if not job_name:
    raise ValueError("job_name must not be empty")


class _DataServiceDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` that reads elements from the tf.data service."""

  def __init__(self,
               dataset_id,
               processing_mode,
               address,
               element_spec,
               protocol,
               data_transfer_protocol,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               task_refresh_interval_hint_ms=None,
               target_workers="AUTO"):
    """Constructs a _DataServiceDatasetV2.

    Args:
      dataset_id: The dataset id for the dataset to read from.
      processing_mode: A `tf.data.experimental.service.ShardingPolicy`
        specifying how to shard the dataset among tf.data workers. See
        `tf.data.experimental.service.ShardingPolicy` for details. For backwards
        compatibility, `processing_mode` may also be set to the strings
        `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
        equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
      address: The tf.data service address, e.g. "localhost:5000".
      element_spec: The dataset element spec for the dataset to read from.
      protocol: The protocol to use for communicating with the tf.data service,
        e.g. "grpc".
      data_transfer_protocol: (Optional.) The protocol to use for transferring
        data with the tf.data service. By default, data is transferred using
        gRPC.
      job_name: (Optional.) The name of the job. If provided, it must be a
        non-empty string or Tensor. This argument makes it possible
        for multiple datasets to share the same job. The default behavior is
        that the dataset creates anonymous, exclusively owned jobs.
      consumer_index: (Optional.) The index of the consumer in the range from
        `0` to `num_consumers`. Must be specified alongside `num_consumers`.
        When specified, consumers will read from the job in a strict round-robin
        order, instead of the default first-come-first-served order.
      num_consumers: (Optional.) The number of consumers which will consume from
        the job. Must be specified alongside `consumer_index`. When specified,
        consumers will read from the job in a strict round-robin order, instead
        of the default first-come-first-served order. When `num_consumers` is
        specified, the dataset must have infinite cardinality to prevent a
        producer from running out of data early and causing consumers to go out
        of sync.
      max_outstanding_requests: (Optional.) A limit on how many elements may be
        requested at the same time. You can use this option to control the
        amount of memory used, since `distribute` won't use more than
        `element_size` * `max_outstanding_requests` of memory.
      task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
        the dispatcher for task changes.
      target_workers: (Optional.) Which workers to read from. If `"AUTO"`,
        tf.data runtime decides which workers to read from. If `"ANY"`, reads
        from any tf.data service workers. If `"LOCAL"`, only reads from local
        in-process tf.data service workers. `"AUTO"` works well for most cases,
        while users can specify other targets. For example, `"LOCAL"` helps
        avoid RPCs and data copy if every TF worker colocates with a tf.data
        service worker. Consumers of a shared job must use the same
        `target_workers`. Defaults to `"AUTO"`.
    """
    processing_mode = _serialize(
        _get_validated_sharding_policy(processing_mode))
    if (consumer_index is None) != (num_consumers is None):
      raise ValueError(
          "Must either set both consumer_index and num_consumers, or neither. "
          "consumer_index: {}, num_consumers: {}".format(
              consumer_index, num_consumers))
    if num_consumers is not None and job_name is None:
      raise ValueError("job_name must be set when setting num_consumers")

    if job_name is None:
      job_name = ""
    if max_outstanding_requests is None:
      max_outstanding_requests = dataset_ops.AUTOTUNE
    if task_refresh_interval_hint_ms is None:
      task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE

    self._dataset_id = ops.convert_to_tensor(
        dataset_id, dtype=dtypes.int64, name="dataset_id")
    self._processing_mode = ops.convert_to_tensor(
        processing_mode, dtype=dtypes.string, name="processing_mode")
    self._address = ops.convert_to_tensor(
        address, dtype=dtypes.string, name="address")
    self._protocol = ops.convert_to_tensor(
        protocol, dtype=dtypes.string, name="protocol")
    self._job_name = ops.convert_to_tensor(
        job_name, dtype=dtypes.string, name="job_name")
    self._consumer_index = ops.convert_to_tensor(
        -1 if consumer_index is None else consumer_index,
        dtype=dtypes.int64,
        name="consumer_index")
    self._num_consumers = ops.convert_to_tensor(
        -1 if num_consumers is None else num_consumers,
        dtype=dtypes.int64,
        name="num_consumers")
    self._max_outstanding_requests = ops.convert_to_tensor(
        max_outstanding_requests,
        dtype=dtypes.int64,
        name="max_outstanding_requests")
    self._element_spec = element_spec
    self._target_workers = target_workers

    compat_kwargs = {}
    if data_transfer_protocol is not None:
      compat_kwargs["data_transfer_protocol"] = data_transfer_protocol
    if compat.forward_compatible(2021, 7, 12) or target_workers != "AUTO":
      compat_kwargs["target_workers"] = target_workers

    variant_tensor = gen_experimental_dataset_ops.data_service_dataset_v2(
        dataset_id=self._dataset_id,
        processing_mode=self._processing_mode,
        address=self._address,
        protocol=self._protocol,
        job_name=self._job_name,
        consumer_index=self._consumer_index,
        num_consumers=self._num_consumers,
        max_outstanding_requests=self._max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        iteration_counter=gen_experimental_dataset_ops.dummy_iteration_counter(
        ),
        **compat_kwargs,
        **self._flat_structure)
    super(_DataServiceDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` that executes its input through the tf.data service."""

  @functools.wraps(_DataServiceDatasetV2.__init__)
  def __init__(self, dataset_id, processing_mode, address, element_spec,
               protocol, data_transfer_protocol, job_name, consumer_index,
               num_consumers, max_outstanding_requests,
               task_refresh_interval_hint_ms, target_workers):

    self._wrapped = _DataServiceDatasetV2(
        dataset_id=dataset_id,
        processing_mode=processing_mode,
        address=address,
        element_spec=element_spec,
        protocol=protocol,
        data_transfer_protocol=data_transfer_protocol,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        target_workers=target_workers)
    super(_DataServiceDatasetV1, self).__init__(self._wrapped)


if tf2.enabled():
  _DataServiceDataset = _DataServiceDatasetV2
else:
  _DataServiceDataset = _DataServiceDatasetV1


def _parse_service(service):
  """Converts a tf.data service string into a (protocol, address) tuple.

  Args:
    service: A string in the format "protocol://address" or just "address". If
      the string is only an address, the default protocol will be used.

  Returns:
    The (protocol, address) tuple.
  """
  if not isinstance(service, six.string_types):
    raise ValueError(
        "service must be a string, but service was of type {0}. service={1}"
        .format(type(service), service))
  if not service:
    raise ValueError("service must not be empty")
  parts = service.split("://")
  if len(parts) == 2:
    protocol, address = parts
  elif len(parts) == 1:
    address = parts[0]
    protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
  else:
    raise ValueError("malformed service string has multiple '://': %s" %
                     service)
  # TODO(aaudibert): Consider validating reachability of address here.
  return (protocol, address)


def _distribute(processing_mode,
                service,
                job_name=None,
                consumer_index=None,
                num_consumers=None,
                max_outstanding_requests=None,
                task_refresh_interval_hint_ms=None,
                data_transfer_protocol=None,
                compression="AUTO",
                target_workers="AUTO"):
  """A transformation that moves dataset processing to the tf.data service.

  This transformation is similar to `distribute`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from `0`
      to `num_consumers`. Must be specified alongside `num_consumers`. When
      specified, consumers will read from the job in a strict round-robin order,
      instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead of
      the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
      dispatcher for task changes.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using gRPC.
    compression: How to compress the dataset's elements before transferring them
      over the network. "AUTO" leaves the decision of how to compress up to the
      tf.data service runtime. `None` indicates not to compress.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data
      runtime decides which workers to read from. If `"ANY"`, reads from any
      tf.data service workers. If `"LOCAL"`, only reads from local in-process
      tf.data service workers. `"AUTO"` works well for most cases, while users
      can specify other targets. For example, `"LOCAL"` helps avoid RPCs and
      data copy if every TF worker colocates with a tf.data service worker.
      Consumers of a shared job must use the same `target_workers`. Defaults
      to `"AUTO"`.

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  processing_mode = _get_validated_sharding_policy(processing_mode)
  valid_compressions = [COMPRESSION_AUTO, COMPRESSION_NONE]
  if compression not in valid_compressions:
    raise ValueError(
        "Invalid compression argument: {}. Must be one of {}".format(
            compression, valid_compressions))
  if compression == COMPRESSION_AUTO and data_transfer_protocol is not None:
    compression = COMPRESSION_NONE
  def _apply_fn(dataset):  # pylint: disable=missing-docstring
    dataset_id = _register_dataset(service, dataset, compression=compression)
    return _from_dataset_id(
        processing_mode,
        service,
        dataset_id,
        dataset.element_spec,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        data_transfer_protocol=data_transfer_protocol,
        compression=compression,
        target_workers=target_workers)

  return _apply_fn


@tf_export("data.experimental.service.distribute")
def distribute(processing_mode,
               service,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               data_transfer_protocol=None,
               compression="AUTO",
               target_workers="AUTO"):
  """A transformation that moves dataset processing to the tf.data service.

  When you iterate over a dataset containing the `distribute` transformation,
  the tf.data service creates a "job" which produces data for the dataset
  iteration.

  The tf.data service uses a cluster of workers to prepare data for training
  your model.
  The `processing_mode` argument to `tf.data.experimental.service.distribute`
  describes how to leverage multiple workers to process the input dataset.
  Currently, there are two processing modes to choose from: "distributed_epoch"
  and "parallel_epochs".

  "distributed_epoch" means that the dataset will be split across all tf.data
  service workers.
  The dispatcher produces "splits" for the dataset and sends them to workers for
  further processing. For example, if a dataset begins with a list of filenames,
  the dispatcher will iterate through the filenames and send the filenames to
  tf.data workers, which will perform the rest of the dataset transformations on
  those files. "distributed_epoch" is useful when your model needs to see each
  element of the dataset exactly once, or if it needs to see the data in a
  generally-sequential order. "distributed_epoch" only works for datasets with
  splittable sources, such as `Dataset.from_tensor_slices`,
  `Dataset.list_files`, or `Dataset.range`.

  "parallel_epochs" means that the entire input dataset will be processed
  independently by each of the tf.data service workers.
  For this reason, it is important to shuffle data (e.g. filenames)
  non-deterministically, so that each worker will process the elements of the
  dataset in a different order. "parallel_epochs" can be used to distribute
  datasets that aren't splittable.

  With two workers, "parallel_epochs" will produce every element of the dataset
  twice:

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> # Start two workers
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9]

  "distributed_epoch", on the other hand, will still produce each element once:

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="distributed_epoch", service=dispatcher.target))
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When using `apply(tf.data.experimental.service.distribute(...))`, the dataset
  before the `apply` transformation executes within the tf.data service, while
  the operations after `apply` happen within the local process.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(5)
  >>> dataset = dataset.map(lambda x: x*x)
  >>> dataset = dataset.apply(
  ...    tf.data.experimental.service.distribute("parallel_epochs",
  ...                                            dispatcher.target))
  >>> dataset = dataset.map(lambda x: x+1)
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [1, 1, 2, 2, 5, 5, 10, 10, 17, 17]

  In the above example, the dataset operations (before applying the `distribute`
  function on the elements) will be executed on the tf.data workers,
  and the elements are provided over RPC. The remaining transformations
  (after the call to `distribute`) will be executed locally. The dispatcher
  and the workers will bind to unused free ports (which are chosen at random),
  in order to communicate with each other. However, to bind them to specific
  ports, the `port` parameter can be passed.

  The `job_name` argument allows jobs to be shared across multiple
  datasets. Instead of each dataset creating its own job, all
  datasets with the same `job_name` will consume from the same job. A new job
  will be created for each iteration of the dataset (with each repetition of
  `Dataset.repeat` counting as a new iteration). Suppose the `DispatchServer`
  is serving on `localhost:5000` and two training workers (in either a single
  client or multi-client setup) iterate over the below dataset, and there is a
  single tf.data worker:

  ```
  range5_dataset = tf.data.Dataset.range(5)
  dataset = range5_dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "localhost:5000", job_name="my_job_name"))
  for iteration in range(3):
    print(list(dataset))
  ```

  The elements of each job will be split between the two processes, with
  elements being consumed by the processes on a first-come first-served basis.
  One possible result is that process 1 prints

  ```
  [0, 2, 4]
  [0, 1, 3]
  [1]
  ```

  and process 2 prints

  ```
  [1, 3]
  [2, 4]
  [0, 2, 3, 4]
  ```

  Job names must not be re-used across different training jobs within the
  lifetime of the tf.data service. In general, the tf.data service is expected
  to live for the duration of a single training job.
  To use the tf.data service with multiple training jobs, make sure to use
  different job names to avoid conflicts. For example, suppose a training job
  calls `distribute` with `job_name="job"` and reads until end of input. If
  another independent job connects to the same tf.data service and tries to read
  from `job_name="job"`, it will immediately receive end of input, without
  getting any data.
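
  One simple way to avoid such conflicts is to derive the job name from an
  identifier that is unique to each training run. A minimal sketch (here
  `run_id` is a hypothetical per-run value, not part of this API):

  ```
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "localhost:5000",
      job_name="my_job_name_" + str(run_id)))
  ```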

  **Round-robin data consumption**

  By default, when multiple consumers read from the same job, they receive data
  on a first-come first-served basis. In some use cases, it works better to use
  a strict round-robin order. For example, the tf.data service can be used to
  coordinate example sizes across a cluster during synchronous training, so that
  during each step all replicas train on similar-sized elements. To achieve
  this, define a dataset which generates rounds of `num_consumers` consecutive
  similar-sized batches, then enable round-robin reads by setting
  `consumer_index` and `num_consumers`, as sketched below.

  Consumers read data by cycling through all workers, reading one element from
  each. First, each consumer will read an element from the first worker, then
  each consumer will read an element from the second worker, and so on.

  NOTE: To keep consumers in sync, round-robin data consumption requires that
  the dataset have infinite cardinality. You can get this by adding `.repeat()`
  at the end of the dataset definition.
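
  A minimal sketch of round-robin consumption for one of two consumers (the
  dispatcher address and `my_consumer_index` are placeholders, not part of
  this API):

  ```
  dataset = dataset.repeat()  # Round-robin reads need infinite cardinality.
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      processing_mode="parallel_epochs",
      service="grpc://dispatcher:5000",
      job_name="round_robin_job",
      consumer_index=my_consumer_index,  # 0 on one consumer, 1 on the other.
      num_consumers=2))
  ```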

  **Keras and Distribution Strategies**

  The dataset produced by the `distribute` transformation can be passed to
  Keras' `Model.fit` or Distribution Strategy's
  `tf.distribute.Strategy.experimental_distribute_dataset` like any other
  `tf.data.Dataset`. We recommend setting a `job_name` on the call to
  `distribute` so that if there are multiple workers, they read data from the
  same job. Note that the autosharding normally performed by
  `experimental_distribute_dataset` will be disabled when setting a `job_name`,
  since sharing the job already results in splitting data across the workers.
  When using a shared job, data will be dynamically balanced across workers, so
  that they reach end of input about the same time. This results in better
  worker utilization than with autosharding, where each worker processes an
  independent set of files, and some workers may run out of data earlier than
  others.
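
  A minimal sketch of this pattern (the strategy choice and dispatcher address
  are placeholders):

  ```
  strategy = tf.distribute.MirroredStrategy()
  dataset = tf.data.Dataset.range(100).batch(8)
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      processing_mode="parallel_epochs",
      service="grpc://dispatcher:5000",
      job_name="shared_training_job"))
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
  ```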

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from `0`
      to `num_consumers`. Must be specified alongside `num_consumers`. When
      specified, consumers will read from the job in a strict round-robin order,
      instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead of
      the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using gRPC.
    compression: How to compress the dataset's elements before transferring them
      over the network. "AUTO" leaves the decision of how to compress up to the
      tf.data service runtime. `None` indicates not to compress.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data
      runtime decides which workers to read from. If `"ANY"`, reads from any
      tf.data service workers. If `"LOCAL"`, only reads from local in-process
      tf.data service workers. `"AUTO"` works well for most cases, while users
      can specify other targets. For example, `"LOCAL"` helps avoid RPCs and
      data copy if every TF worker colocates with a tf.data service worker.
      Consumers of a shared job must use the same `target_workers`. Defaults
      to `"AUTO"`.

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  _validate_job_name(job_name)
  return _distribute(
      processing_mode=processing_mode,
      service=service,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      data_transfer_protocol=data_transfer_protocol,
      compression=compression,
      target_workers=target_workers)


def _register_dataset(service, dataset, compression):
  """Registers a dataset with the tf.data service.

  This function is similar to `register_dataset`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset: A `tf.data.Dataset` to register with the tf.data service.
    compression: How to compress the dataset's elements before transferring them
      over the network. "AUTO" leaves the decision of how to compress up to the
      tf.data service runtime. `None` indicates not to compress.

  Returns:
    A scalar int64 tensor of the registered dataset's id.
  """
  valid_compressions = [COMPRESSION_AUTO, COMPRESSION_NONE]
  if compression not in valid_compressions:
    raise ValueError(
        "Invalid compression argument: {}. Must be one of {}".format(
            compression, valid_compressions))
  if isinstance(service, tuple):
    protocol, address = service
  else:
    protocol, address = _parse_service(service)
  external_state_policy = dataset.options().experimental_external_state_policy
  if external_state_policy is None:
    external_state_policy = ExternalStatePolicy.WARN

  encoded_spec = ""
  if context.executing_eagerly():
    coder = nested_structure_coder.StructureCoder()
    encoded_spec = coder.encode_structure(
        dataset.element_spec).SerializeToString()

  if compression == COMPRESSION_AUTO:
    dataset = dataset.map(
        lambda *x: compression_ops.compress(x),
        num_parallel_calls=dataset_ops.AUTOTUNE)
  dataset = dataset.prefetch(dataset_ops.AUTOTUNE)
  dataset = dataset._apply_debug_options()  # pylint: disable=protected-access

  dataset_id = gen_experimental_dataset_ops.register_dataset(
      dataset._variant_tensor,  # pylint: disable=protected-access
      address=address,
      protocol=protocol,
      external_state_policy=external_state_policy.value,
      element_spec=encoded_spec)

  return dataset_id


@tf_export("data.experimental.service.register_dataset")
def register_dataset(service, dataset):
  """Registers a dataset with the tf.data service.

  `register_dataset` registers a dataset with the tf.data service so that
  datasets can be created later with
  `tf.data.experimental.service.from_dataset_id`. This is useful when the
  dataset is registered by one process, then used in another process. When the
  same process is both registering and reading from the dataset, it is simpler
  to use `tf.data.experimental.service.distribute` instead.

  If the dataset is already registered with the tf.data service,
  `register_dataset` returns the already-registered dataset's id.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  Args:
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset: A `tf.data.Dataset` to register with the tf.data service.

  Returns:
    A scalar int64 tensor of the registered dataset's id.
  """
  return _register_dataset(service, dataset, compression="AUTO")


def _from_dataset_id(processing_mode,
                     service,
                     dataset_id,
                     element_spec,
                     job_name=None,
                     consumer_index=None,
                     num_consumers=None,
                     max_outstanding_requests=None,
                     task_refresh_interval_hint_ms=None,
                     data_transfer_protocol=None,
                     compression="AUTO",
                     target_workers="AUTO"):
  """Creates a dataset which reads data from the tf.data service.

  This function is similar to `from_dataset_id`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type of
      elements produced by the dataset. This argument is only required inside a
      tf.function. Use `tf.data.Dataset.element_spec` to get the element spec
      for a given dataset.
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string or tensor. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from `0`
      to `num_consumers`. Must be specified alongside `num_consumers`. When
      specified, consumers will read from the job in a strict round-robin order,
      instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead of
      the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
      dispatcher for task changes.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using gRPC.
    compression: An indication of how the dataset's elements were compressed, so
      that `from_dataset_id` can uncompress them if necessary.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data
      runtime decides which workers to read from. If `"ANY"`, reads from any
      tf.data service workers. If `"LOCAL"`, only reads from local in-process
      tf.data service workers. `"AUTO"` works well for most cases, while users
      can specify other targets. For example, `"LOCAL"` helps avoid RPCs and
      data copy if every TF worker colocates with a tf.data service worker.
      Consumers of a shared job must use the same `target_workers`. Defaults
      to `"AUTO"`.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """
  processing_mode = _get_validated_sharding_policy(processing_mode)
  valid_compressions = [COMPRESSION_AUTO, COMPRESSION_NONE]
  if isinstance(service, tuple):
    protocol, address = service
  else:
    protocol, address = _parse_service(service)

  if compression not in valid_compressions:
    raise ValueError(
        "Invalid compression argument: {}. Must be one of {}".format(
            compression, valid_compressions))
  if job_name is not None:
    if not isinstance(job_name, six.string_types) and not isinstance(
        job_name, ops.Tensor):
      raise ValueError(
          "job_name must be a string or Tensor, but job_name was of type "
          "{0}. job_name={1}".format(type(job_name), job_name))

  if element_spec is None:
    if not context.executing_eagerly():
      raise ValueError("In graph mode element_spec must be provided manually.")

    dataset_id_val = tensor_util.constant_value(dataset_id)
    try:
      encoded_spec = _pywrap_server_lib.TF_DATA_GetElementSpec(
          dataset_id_val, address, protocol)

    except NotImplementedError as err:
      raise ValueError("The tf.data service is running an earlier version of "
                       "TensorFlow that requires specifying `element_spec` as "
                       "an argument to `from_dataset_id`. Please either supply "
                       "an element spec or update the tf.data service to the "
                       "latest version.") from err

    except RuntimeError as err:
      raise ValueError("Failed to fetch element spec for dataset id " +
                       str(dataset_id_val) + " from tf.data service. If the "
                       "dataset was registered in graph mode or inside a "
                       "tf.function, the `element_spec` must be specified as "
                       "an argument to `from_dataset_id`.") from err

    struct_pb = nested_structure_coder.struct_pb2.StructuredValue()
    struct_pb.ParseFromString(encoded_spec)
    coder = nested_structure_coder.StructureCoder()
    element_spec = coder.decode_proto(struct_pb)

  # If we compress, the data service side dataset will produce scalar variants.
  data_service_element_spec = (
      tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
      if compression == COMPRESSION_AUTO else element_spec)

  dataset = _DataServiceDataset(
      dataset_id=dataset_id,
      processing_mode=processing_mode,
      address=address,
      element_spec=data_service_element_spec,
      protocol=protocol,
      data_transfer_protocol=data_transfer_protocol,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
      target_workers=target_workers)
  if compression == COMPRESSION_AUTO:
    dataset = dataset.map(
        lambda x: compression_ops.uncompress(x, output_spec=element_spec),
        num_parallel_calls=dataset_ops.AUTOTUNE)

  # Disable autosharding for shared jobs.
  if job_name is not None:
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
    dataset = dataset.with_options(options)
  return dataset


@tf_export("data.experimental.service.from_dataset_id")
def from_dataset_id(processing_mode,
                    service,
                    dataset_id,
                    element_spec=None,
                    job_name=None,
                    consumer_index=None,
                    num_consumers=None,
                    max_outstanding_requests=None,
                    data_transfer_protocol=None,
                    target_workers="AUTO"):
  """Creates a dataset which reads data from the tf.data service.

  This is useful when the dataset is registered by one process, then used in
  another process. When the same process is both registering and reading from
  the dataset, it is simpler to use `tf.data.experimental.service.distribute`
  instead.

  Before using `from_dataset_id`, the dataset must have been registered with the
  tf.data service using `tf.data.experimental.service.register_dataset`.
  `register_dataset` returns a dataset id for the registered dataset. That is
  the `dataset_id` which should be passed to `from_dataset_id`.

  The `element_spec` argument indicates the `tf.TypeSpec`s for the elements
  produced by the dataset. In eager mode, `element_spec` may be omitted, in
  which case it is discovered by querying the tf.data service. Inside a
  `tf.function` (or in graph mode), `element_spec` must be specified explicitly
  and must match the dataset registered under `dataset_id`.

  `tf.data.experimental.service.distribute` is a convenience method which
  combines `register_dataset` and `from_dataset_id` into a dataset
  transformation.
  See the documentation for `tf.data.experimental.service.distribute` for more
  detail about how `from_dataset_id` works.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type of
      elements produced by the dataset. This argument is only required inside a
      tf.function. Use `tf.data.Dataset.element_spec` to get the element spec
      for a given dataset.
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from `0`
      to `num_consumers`. Must be specified alongside `num_consumers`. When
      specified, consumers will read from the job in a strict round-robin order,
      instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead of
      the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using gRPC.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data
      runtime decides which workers to read from. If `"ANY"`, reads from any
      tf.data service workers. If `"LOCAL"`, only reads from local in-process
      tf.data service workers. `"AUTO"` works well for most cases, while users
      can specify other targets. For example, `"LOCAL"` helps avoid RPCs and
      data copy if every TF worker colocates with a tf.data service worker.
      Consumers of a shared job must use the same `target_workers`. Defaults
      to `"AUTO"`.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """
  _validate_job_name(job_name)
  if job_name is not None:
    job_name = string_ops.string_join(
        ["dataset_id=", string_ops.as_string(dataset_id), job_name], "/")

  return _from_dataset_id(
      processing_mode=processing_mode,
      service=service,
      dataset_id=dataset_id,
      element_spec=element_spec,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      data_transfer_protocol=data_transfer_protocol,
      target_workers=target_workers)