• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""API for specifying `tf.data` options."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import enum
22
23from absl import logging
24
25from tensorflow.core.framework import dataset_options_pb2
26from tensorflow.python.data.util import options as options_lib
27from tensorflow.python.util import deprecation
28from tensorflow.python.util.tf_export import tf_export
29
30
@tf_export("data.experimental.AutoShardPolicy")
class AutoShardPolicy(enum.IntEnum):
  """Represents the type of auto-sharding to use.

  OFF: No sharding will be performed.

  AUTO: Attempts FILE-based sharding, falling back to DATA-based sharding.

  FILE: Shards by input files (i.e. each worker will get a set of files to
  process). When this option is selected, make sure that there is at least as
  many files as workers. If there are fewer input files than workers, a runtime
  error will be raised.

  DATA: Shards by elements produced by the dataset. Each worker will process the
  whole dataset and discard the portion that is not for itself. Note that for
  this mode to correctly partition the dataset elements, the dataset needs to
  produce elements in a deterministic order.

  HINT: Looks for the presence of `shard(SHARD_HINT, ...)` which is treated as a
  placeholder to replace with `shard(num_workers, worker_index)`.
  """

  # LINT.IfChange
  OFF = -1
  AUTO = 0
  FILE = 1
  DATA = 2
  HINT = 3
  # LINT.ThenChange(//tensorflow/python/data/experimental/ops/data_service_ops.py:tf_data_service_sharding_policy)

  @classmethod
  def _to_proto(cls, obj):
    """Converts an enum value into its `dataset_options_pb2` counterpart."""
    # `==` comparisons (rather than a dict lookup) keep behavior identical for
    # plain-int inputs, which IntEnum compares by value.
    conversions = (
        (cls.OFF, dataset_options_pb2.AutoShardPolicy.OFF),
        (cls.FILE, dataset_options_pb2.AutoShardPolicy.FILE),
        (cls.DATA, dataset_options_pb2.AutoShardPolicy.DATA),
        (cls.AUTO, dataset_options_pb2.AutoShardPolicy.AUTO),
        (cls.HINT, dataset_options_pb2.AutoShardPolicy.HINT),
    )
    for enum_value, proto_value in conversions:
      if obj == enum_value:
        return proto_value
    raise ValueError("%s._to_proto() is called with undefined enum %s." %
                     (cls.__name__, obj.name))

  @classmethod
  def _from_proto(cls, pb):
    """Converts a `dataset_options_pb2` value back into this enum."""
    conversions = (
        (dataset_options_pb2.AutoShardPolicy.OFF, cls.OFF),
        (dataset_options_pb2.AutoShardPolicy.FILE, cls.FILE),
        (dataset_options_pb2.AutoShardPolicy.DATA, cls.DATA),
        (dataset_options_pb2.AutoShardPolicy.AUTO, cls.AUTO),
        (dataset_options_pb2.AutoShardPolicy.HINT, cls.HINT),
    )
    for proto_value, enum_value in conversions:
      if pb == proto_value:
        return enum_value
    raise ValueError("%s._from_proto() is called with undefined enum %s." %
                     (cls.__name__, pb))
93
@tf_export("data.experimental.ExternalStatePolicy")
class ExternalStatePolicy(enum.Enum):
  """Represents how to handle external state during serialization.

  See the `tf.data.Options.experimental_external_state_policy` documentation
  for more information.
  """
  WARN = 0
  IGNORE = 1
  FAIL = 2

  @classmethod
  def _to_proto(cls, obj):
    """Converts an enum value into its `dataset_options_pb2` counterpart."""
    conversions = (
        (cls.IGNORE, dataset_options_pb2.ExternalStatePolicy.POLICY_IGNORE),
        (cls.FAIL, dataset_options_pb2.ExternalStatePolicy.POLICY_FAIL),
        (cls.WARN, dataset_options_pb2.ExternalStatePolicy.POLICY_WARN),
    )
    for enum_value, proto_value in conversions:
      if obj == enum_value:
        return proto_value
    raise ValueError("%s._to_proto() is called with undefined enum %s." %
                     (cls.__name__, obj.name))

  @classmethod
  def _from_proto(cls, pb):
    """Converts a `dataset_options_pb2` value back into this enum."""
    conversions = (
        (dataset_options_pb2.ExternalStatePolicy.POLICY_IGNORE, cls.IGNORE),
        (dataset_options_pb2.ExternalStatePolicy.POLICY_FAIL, cls.FAIL),
        (dataset_options_pb2.ExternalStatePolicy.POLICY_WARN, cls.WARN),
    )
    for proto_value, enum_value in conversions:
      if pb == proto_value:
        return enum_value
    raise ValueError("%s._from_proto() is called with undefined enum %s." %
                     (cls.__name__, pb))
128
129
@tf_export("data.experimental.AutotuneOptions")
class AutotuneOptions(options_lib.OptionsBase):
  """Represents options for autotuning dataset performance.

  ```python
  options = tf.data.Options()
  options.autotune.enabled = False
  dataset = dataset.with_options(options)
  ```
  """

  enabled = options_lib.create_option(
      name="enabled",
      ty=bool,
      docstring="Whether to automatically tune performance knobs. If None, "
      "defaults to True.")

  cpu_budget = options_lib.create_option(
      name="cpu_budget",
      ty=int,
      docstring="When autotuning is enabled (through `autotune`), determines "
      "the CPU budget to use. Values greater than the number of schedulable "
      "CPU cores are allowed but may result in CPU contention. If None, "
      "defaults to the number of schedulable CPU cores.")

  ram_budget = options_lib.create_option(
      name="ram_budget",
      ty=int,
      docstring="When autotuning is enabled (through `autotune`), determines "
      "the RAM budget to use. Values greater than the available RAM in bytes "
      "may result in OOM. If None, defaults to half of the available RAM in "
      "bytes.")

  def _to_proto(self):
    """Serializes these options into a `dataset_options_pb2.AutotuneOptions`."""
    pb = dataset_options_pb2.AutotuneOptions()
    # Only explicitly-set (non-None) fields are written to the proto.
    for field_name in ("enabled", "cpu_budget", "ram_budget"):
      field_value = getattr(self, field_name)
      if field_value is not None:
        setattr(pb, field_name, field_value)
    return pb

  def _from_proto(self, pb):
    """Restores option values from a `dataset_options_pb2.AutotuneOptions`."""
    # A field is applied only when its oneof wrapper indicates it was set.
    for field_name in ("enabled", "cpu_budget", "ram_budget"):
      if pb.WhichOneof("optional_" + field_name) is not None:
        setattr(self, field_name, getattr(pb, field_name))

  def _set_mutable(self, mutable):
    """Change the mutability value to `mutable` on this options and children."""
    # pylint: disable=protected-access
    object.__setattr__(self, "_mutable", mutable)
186
@tf_export("data.experimental.DistributeOptions")
class DistributeOptions(options_lib.OptionsBase):
  """Represents options for distributed data processing.

  You can set the distribution options of a dataset through the
  `experimental_distribute` property of `tf.data.Options`; the property is
  an instance of `tf.data.experimental.DistributeOptions`.

  ```python
  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
  dataset = dataset.with_options(options)
  ```
  """

  auto_shard_policy = options_lib.create_option(
      name="auto_shard_policy",
      ty=AutoShardPolicy,
      docstring="The type of sharding to use. See "
      "`tf.data.experimental.AutoShardPolicy` for additional information.",
      default_factory=lambda: AutoShardPolicy.AUTO)

  num_devices = options_lib.create_option(
      name="num_devices",
      ty=int,
      docstring=
      "The number of devices attached to this input pipeline. This will be "
      "automatically set by `MultiDeviceIterator`.")

  def _to_proto(self):
    """Serializes these options into a `DistributeOptions` proto."""
    pb = dataset_options_pb2.DistributeOptions()
    # pylint: disable=protected-access
    # `auto_shard_policy` always has a value (AUTO by default), so it is
    # unconditionally written; `num_devices` is optional.
    policy_proto = AutoShardPolicy._to_proto(self.auto_shard_policy)
    pb.auto_shard_policy = policy_proto
    if self.num_devices is not None:
      pb.num_devices = self.num_devices
    return pb

  def _from_proto(self, pb):
    """Restores option values from a `DistributeOptions` proto."""
    # pylint: disable=protected-access
    policy = AutoShardPolicy._from_proto(pb.auto_shard_policy)
    self.auto_shard_policy = policy
    if pb.WhichOneof("optional_num_devices") is not None:
      self.num_devices = pb.num_devices
227
228
@tf_export("data.experimental.OptimizationOptions")
class OptimizationOptions(options_lib.OptionsBase):
  """Represents options for dataset optimizations.

  You can set the optimization options of a dataset through the
  `experimental_optimization` property of `tf.data.Options`; the property is
  an instance of `tf.data.experimental.OptimizationOptions`.

  ```python
  options = tf.data.Options()
  options.experimental_optimization.noop_elimination = True
  options.experimental_optimization.apply_default_optimizations = False
  dataset = dataset.with_options(options)
  ```
  """
  apply_default_optimizations = options_lib.create_option(
      name="apply_default_optimizations",
      ty=bool,
      docstring=
      "Whether to apply default graph optimizations. If False, only graph "
      "optimizations that have been explicitly enabled will be applied.")

  filter_fusion = options_lib.create_option(
      name="filter_fusion",
      ty=bool,
      docstring=
      "Whether to fuse filter transformations. If None, defaults to False.")

  map_and_batch_fusion = options_lib.create_option(
      name="map_and_batch_fusion",
      ty=bool,
      docstring=
      "Whether to fuse map and batch transformations. If None, defaults to "
      "True.")

  map_and_filter_fusion = options_lib.create_option(
      name="map_and_filter_fusion",
      ty=bool,
      docstring=
      "Whether to fuse map and filter transformations. If None, defaults to "
      "False.")

  map_fusion = options_lib.create_option(
      name="map_fusion",
      ty=bool,
      docstring="Whether to fuse map transformations. If None, defaults to "
      "False.")

  map_parallelization = options_lib.create_option(
      name="map_parallelization",
      ty=bool,
      docstring=
      "Whether to parallelize stateless map transformations. If None, defaults "
      "to True.")

  noop_elimination = options_lib.create_option(
      name="noop_elimination",
      ty=bool,
      docstring=
      "Whether to eliminate no-op transformations. If None, defaults to True.")

  parallel_batch = options_lib.create_option(
      name="parallel_batch",
      ty=bool,
      docstring="Whether to parallelize copying of batch elements. This "
      "optimization is highly experimental and can cause performance "
      "degradation (e.g. when the parallelization overhead exceeds the "
      "benefits of performing the data copies in parallel). You should only "
      "enable this optimization if a) your input pipeline is bottlenecked on "
      "batching and b) you have validated that this optimization improves "
      "performance. If None, defaults to False.")

  shuffle_and_repeat_fusion = options_lib.create_option(
      name="shuffle_and_repeat_fusion",
      ty=bool,
      docstring="Whether to fuse shuffle and repeat transformations. If None, "
      "defaults to True.")

  def _to_proto(self):
    """Serializes these options into an `OptimizationOptions` proto."""
    pb = dataset_options_pb2.OptimizationOptions()
    # Every option maps one-to-one onto an identically-named proto field;
    # only explicitly-set (non-None) values are written.
    for field_name in ("apply_default_optimizations", "filter_fusion",
                       "map_and_batch_fusion", "map_and_filter_fusion",
                       "map_fusion", "map_parallelization", "noop_elimination",
                       "parallel_batch", "shuffle_and_repeat_fusion"):
      field_value = getattr(self, field_name)
      if field_value is not None:
        setattr(pb, field_name, field_value)
    return pb

  def _from_proto(self, pb):
    """Restores option values from an `OptimizationOptions` proto."""
    # A field is applied only when its oneof wrapper indicates it was set.
    for field_name in ("apply_default_optimizations", "filter_fusion",
                       "map_and_batch_fusion", "map_and_filter_fusion",
                       "map_fusion", "map_parallelization", "noop_elimination",
                       "parallel_batch", "shuffle_and_repeat_fusion"):
      if pb.WhichOneof("optional_" + field_name) is not None:
        setattr(self, field_name, getattr(pb, field_name))

  def _set_mutable(self, mutable):
    """Change the mutability value to `mutable` on this options and children."""
    # pylint: disable=protected-access
    object.__setattr__(self, "_mutable", mutable)
353
354
@deprecation.deprecated_endpoints("data.experimental.ThreadingOptions")
@tf_export("data.experimental.ThreadingOptions", "data.ThreadingOptions")
class ThreadingOptions(options_lib.OptionsBase):
  """Represents options for dataset threading.

  You can set the threading options of a dataset through the
  `experimental_threading` property of `tf.data.Options`; the property is
  an instance of `tf.data.ThreadingOptions`.

  ```python
  options = tf.data.Options()
  options.threading.private_threadpool_size = 10
  dataset = dataset.with_options(options)
  ```
  """

  max_intra_op_parallelism = options_lib.create_option(
      name="max_intra_op_parallelism",
      ty=int,
      docstring=
      "If set, it overrides the maximum degree of intra-op parallelism.")

  private_threadpool_size = options_lib.create_option(
      name="private_threadpool_size",
      ty=int,
      docstring=
      "If set, the dataset will use a private threadpool of the given size. "
      "The value 0 can be used to indicate that the threadpool size should be "
      "determined at runtime based on the number of available CPU cores.")

  def _to_proto(self):
    """Serializes these options into a `ThreadingOptions` proto."""
    pb = dataset_options_pb2.ThreadingOptions()
    # Only explicitly-set (non-None) fields are written to the proto.
    for field_name in ("max_intra_op_parallelism", "private_threadpool_size"):
      field_value = getattr(self, field_name)
      if field_value is not None:
        setattr(pb, field_name, field_value)
    return pb

  def _from_proto(self, pb):
    """Restores option values from a `ThreadingOptions` proto."""
    # A field is applied only when its oneof wrapper indicates it was set.
    for field_name in ("max_intra_op_parallelism", "private_threadpool_size"):
      if pb.WhichOneof("optional_" + field_name) is not None:
        setattr(self, field_name, getattr(pb, field_name))
398
399
@tf_export("data.Options")
class Options(options_lib.OptionsBase):
  """Represents options for `tf.data.Dataset`.

  A `tf.data.Options` object can be, for instance, used to control which static
  optimizations to apply to the input pipeline graph or whether to use
  performance modeling to dynamically tune the parallelism of operations such as
  `tf.data.Dataset.map` or `tf.data.Dataset.interleave`.

  The options are set for the entire dataset and are carried over to datasets
  created through tf.data transformations.

  The options can be set by constructing an `Options` object and using the
  `tf.data.Dataset.with_options(options)` transformation, which returns a
  dataset with the options set.

  >>> dataset = tf.data.Dataset.range(42)
  >>> options = tf.data.Options()
  >>> options.deterministic = False
  >>> dataset = dataset.with_options(options)
  >>> print(dataset.options().deterministic)
  False

  Note: A known limitation of the `tf.data.Options` implementation is that the
  options are not preserved across tf.function boundaries. In particular, to
  set options for a dataset that is iterated within a tf.function, the options
  need to be set within the same tf.function.
  """

  autotune = options_lib.create_option(
      name="autotune",
      ty=AutotuneOptions,
      docstring="The autotuning options associated with the dataset. See "
      "`tf.data.experimental.AutotuneOptions` for more details.",
      default_factory=AutotuneOptions)

  deterministic = options_lib.create_option(
      name="deterministic",
      ty=bool,
      docstring=
      "Whether the outputs need to be produced in deterministic order. If None,"
      " defaults to True.")

  # Deprecated alias for `deterministic`; reads and writes are redirected in
  # `__getattribute__`/`__setattr__` below, so this option never holds a value
  # of its own.
  experimental_deterministic = options_lib.create_option(
      name="experimental_deterministic",
      ty=bool,
      docstring="DEPRECATED. Use `deterministic` instead.")

  experimental_distribute = options_lib.create_option(
      name="experimental_distribute",
      ty=DistributeOptions,
      docstring=
      "The distribution strategy options associated with the dataset. See "
      "`tf.data.experimental.DistributeOptions` for more details.",
      default_factory=DistributeOptions)

  experimental_external_state_policy = options_lib.create_option(
      name="experimental_external_state_policy",
      ty=ExternalStatePolicy,
      docstring="This option can be used to override the default policy for "
      "how to handle external state when serializing a dataset or "
      "checkpointing its iterator. There are three settings available - "
      "IGNORE: External state is ignored without a warning; WARN: External "
      "state is ignored and a warning is logged; FAIL: External state results "
      "in an error.")

  experimental_optimization = options_lib.create_option(
      name="experimental_optimization",
      ty=OptimizationOptions,
      docstring=
      "The optimization options associated with the dataset. See "
      "`tf.data.experimental.OptimizationOptions` for more details.",
      default_factory=OptimizationOptions)

  experimental_slack = options_lib.create_option(
      name="experimental_slack",
      ty=bool,
      docstring="Whether to introduce 'slack' in the last `prefetch` of the "
      "input pipeline, if it exists. This may reduce CPU contention with "
      "accelerator host-side activity at the start of a step. The slack "
      "frequency is determined by the number of devices attached to this "
      "input pipeline. If None, defaults to False.")

  # Deprecated alias for `threading`; reads and writes are redirected in
  # `__getattribute__`/`__setattr__` below, so this option never holds a value
  # of its own.
  experimental_threading = options_lib.create_option(
      name="experimental_threading",
      ty=ThreadingOptions,
      docstring="DEPRECATED. Use `threading` instead.")

  threading = options_lib.create_option(
      name="threading",
      ty=ThreadingOptions,
      docstring="The threading options associated with the dataset. See "
      "`tf.data.ThreadingOptions` for more details.",
      default_factory=ThreadingOptions)

  def __getattribute__(self, name):
    # Redirects reads of the deprecated aliases to their replacements. The
    # `getattr` calls below re-enter this method, but with names that fall
    # through to the default lookup, so there is no unbounded recursion.
    if name == "experimental_threading":
      logging.warning("options.experimental_threading is deprecated. "
                      "Use options.threading instead.")
      return getattr(self, "threading")
    if name == "experimental_deterministic":
      # TODO(aaudibert): Uncomment after internal uses have been updated.
      # logging.warning("options.experimental_deterministic is deprecated. "
      #                 "Use options.deterministic instead.")
      return getattr(self, "deterministic")
    return super(Options, self).__getattribute__(name)

  def __setattr__(self, name, value):
    # Redirects writes of the deprecated aliases to their replacements, so
    # the canonical option is the single source of truth either way the
    # caller sets it.
    if name == "experimental_threading":
      logging.warning("options.experimental_threading is deprecated. "
                      "Use options.threading instead.")
      super(Options, self).__setattr__("threading", value)
      return
    if name == "experimental_deterministic":
      # TODO(aaudibert): Uncomment after internal uses have been updated.
      # logging.warning("options.experimental_deterministic is deprecated. "
      #                 "Use options.deterministic instead.")
      super(Options, self).__setattr__("deterministic", value)
      return
    super(Options, self).__setattr__(name, value)

  def _to_proto(self):
    """Serializes these options into a `dataset_options_pb2.Options` proto.

    Nested option messages are always populated; scalar fields are written
    only when they were explicitly set (non-None).
    """
    pb = dataset_options_pb2.Options()
    if self.deterministic is not None:
      pb.deterministic = self.deterministic
    pb.autotune_options.CopyFrom(self.autotune._to_proto())  # pylint: disable=protected-access
    pb.distribute_options.CopyFrom(self.experimental_distribute._to_proto())  # pylint: disable=protected-access
    if self.experimental_external_state_policy is not None:
      pb.external_state_policy = (
          ExternalStatePolicy._to_proto(  # pylint: disable=protected-access
              self.experimental_external_state_policy))
    pb.optimization_options.CopyFrom(self.experimental_optimization._to_proto())  # pylint: disable=protected-access
    if self.experimental_slack is not None:
      pb.slack = self.experimental_slack
    pb.threading_options.CopyFrom(self.threading._to_proto())  # pylint: disable=protected-access
    return pb

  def _from_proto(self, pb):
    """Restores option values from a `dataset_options_pb2.Options` proto.

    Scalar fields are applied only when their oneof wrapper indicates they
    were set; nested option messages are always delegated to the child's
    `_from_proto`.
    """
    if pb.WhichOneof("optional_deterministic") is not None:
      self.deterministic = pb.deterministic
    self.autotune._from_proto(pb.autotune_options)  # pylint: disable=protected-access
    self.experimental_distribute._from_proto(pb.distribute_options)  # pylint: disable=protected-access
    if pb.WhichOneof("optional_external_state_policy") is not None:
      self.experimental_external_state_policy = (
          ExternalStatePolicy._from_proto(  # pylint: disable=protected-access
              pb.external_state_policy))
    self.experimental_optimization._from_proto(pb.optimization_options)  # pylint: disable=protected-access
    if pb.WhichOneof("optional_slack") is not None:
      self.experimental_slack = pb.slack
    self.threading._from_proto(pb.threading_options)  # pylint: disable=protected-access

  def _set_mutable(self, mutable):
    """Change the mutability value to `mutable` on this options and children."""
    # Propagated to nested options objects so the whole options tree is
    # frozen/unfrozen together.
    # pylint: disable=protected-access
    object.__setattr__(self, "_mutable", mutable)
    self.autotune._set_mutable(mutable)
    self.experimental_distribute._set_mutable(mutable)
    self.experimental_optimization._set_mutable(mutable)
    self.threading._set_mutable(mutable)

  def merge(self, options):
    """Merges itself with the given `tf.data.Options`.

    If this object and the `options` to merge set an option differently, a
    warning is generated and this object's value is updated with the `options`
    object's value.

    Args:
      options: The `tf.data.Options` to merge with.

    Returns:
      New `tf.data.Options` object which is the result of merging self with
      the input `tf.data.Options`.
    """
    return options_lib.merge_options(self, options)
575