# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""API for specifying `tf.data` options."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum

from absl import logging

from tensorflow.core.framework import dataset_options_pb2
from tensorflow.python.data.util import options as options_lib
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.AutoShardPolicy")
class AutoShardPolicy(enum.IntEnum):
  """Represents the type of auto-sharding to use.

  OFF: No sharding will be performed.

  AUTO: Attempts FILE-based sharding, falling back to DATA-based sharding.

  FILE: Shards by input files (i.e. each worker will get a set of files to
  process). When this option is selected, make sure that there are at least as
  many files as workers. If there are fewer input files than workers, a runtime
  error will be raised.

  DATA: Shards by elements produced by the dataset. Each worker will process
  the whole dataset and discard the portion that is not for itself. Note that
  for this mode to correctly partition the dataset elements, the dataset needs
  to produce elements in a deterministic order.

  HINT: Looks for the presence of `shard(SHARD_HINT, ...)`, which is treated as
  a placeholder to replace with `shard(num_workers, worker_index)`.
  """

  # LINT.IfChange
  OFF = -1
  AUTO = 0
  FILE = 1
  DATA = 2
  HINT = 3
  # LINT.ThenChange(//tensorflow/python/data/experimental/ops/data_service_ops.py:tf_data_service_sharding_policy)

  @classmethod
  def _to_proto(cls, obj):
    """Convert enum to proto."""
    if obj == cls.OFF:
      return dataset_options_pb2.AutoShardPolicy.OFF
    if obj == cls.FILE:
      return dataset_options_pb2.AutoShardPolicy.FILE
    if obj == cls.DATA:
      return dataset_options_pb2.AutoShardPolicy.DATA
    if obj == cls.AUTO:
      return dataset_options_pb2.AutoShardPolicy.AUTO
    if obj == cls.HINT:
      return dataset_options_pb2.AutoShardPolicy.HINT
    raise ValueError("%s._to_proto() is called with undefined enum %s." %
                     (cls.__name__, obj.name))

  @classmethod
  def _from_proto(cls, pb):
    """Convert proto to enum."""
    if pb == dataset_options_pb2.AutoShardPolicy.OFF:
      return cls.OFF
    if pb == dataset_options_pb2.AutoShardPolicy.FILE:
      return cls.FILE
    if pb == dataset_options_pb2.AutoShardPolicy.DATA:
      return cls.DATA
    if pb == dataset_options_pb2.AutoShardPolicy.AUTO:
      return cls.AUTO
    if pb == dataset_options_pb2.AutoShardPolicy.HINT:
      return cls.HINT
    raise ValueError("%s._from_proto() is called with undefined enum %s." %
                     (cls.__name__, pb))
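
# Illustrative usage sketch (not executed here): an auto-shard policy is
# selected through `tf.data.Options.experimental_distribute.auto_shard_policy`,
# e.g. to force element-wise (DATA) sharding:
#
#   options = tf.data.Options()
#   options.experimental_distribute.auto_shard_policy = AutoShardPolicy.DATA
#   dataset = dataset.with_options(options)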


@tf_export("data.experimental.ExternalStatePolicy")
class ExternalStatePolicy(enum.Enum):
  """Represents how to handle external state during serialization.

  See the `tf.data.Options.experimental_external_state_policy` documentation
  for more information.
  """
  WARN = 0
  IGNORE = 1
  FAIL = 2

  @classmethod
  def _to_proto(cls, obj):
    """Convert enum to proto."""
    if obj == cls.IGNORE:
      return dataset_options_pb2.ExternalStatePolicy.POLICY_IGNORE
    if obj == cls.FAIL:
      return dataset_options_pb2.ExternalStatePolicy.POLICY_FAIL
    if obj == cls.WARN:
      return dataset_options_pb2.ExternalStatePolicy.POLICY_WARN
    raise ValueError("%s._to_proto() is called with undefined enum %s." %
                     (cls.__name__, obj.name))

  @classmethod
  def _from_proto(cls, pb):
    """Convert proto to enum."""
    if pb == dataset_options_pb2.ExternalStatePolicy.POLICY_IGNORE:
      return cls.IGNORE
    if pb == dataset_options_pb2.ExternalStatePolicy.POLICY_FAIL:
      return cls.FAIL
    if pb == dataset_options_pb2.ExternalStatePolicy.POLICY_WARN:
      return cls.WARN
    raise ValueError("%s._from_proto() is called with undefined enum %s." %
                     (cls.__name__, pb))
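
# Illustrative usage sketch (not executed here): the policy is applied through
# `tf.data.Options.experimental_external_state_policy`, e.g. to fail fast when
# serializing a dataset that captures external state:
#
#   options = tf.data.Options()
#   options.experimental_external_state_policy = ExternalStatePolicy.FAIL
#   dataset = dataset.with_options(options)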


@tf_export("data.experimental.AutotuneOptions")
class AutotuneOptions(options_lib.OptionsBase):
  """Represents options for autotuning dataset performance.

  ```python
  options = tf.data.Options()
  options.autotune.enabled = False
  dataset = dataset.with_options(options)
  ```
  """

  enabled = options_lib.create_option(
      name="enabled",
      ty=bool,
      docstring="Whether to automatically tune performance knobs. If None, "
      "defaults to True.")

  cpu_budget = options_lib.create_option(
      name="cpu_budget",
      ty=int,
      docstring="When autotuning is enabled (through `autotune`), determines "
      "the CPU budget to use. Values greater than the number of schedulable "
      "CPU cores are allowed but may result in CPU contention. If None, "
      "defaults to the number of schedulable CPU cores.")

  ram_budget = options_lib.create_option(
      name="ram_budget",
      ty=int,
      docstring="When autotuning is enabled (through `autotune`), determines "
      "the RAM budget to use. Values greater than the available RAM in bytes "
      "may result in OOM. If None, defaults to half of the available RAM in "
      "bytes.")

  def _to_proto(self):
    pb = dataset_options_pb2.AutotuneOptions()
    if self.enabled is not None:
      pb.enabled = self.enabled
    if self.cpu_budget is not None:
      pb.cpu_budget = self.cpu_budget
    if self.ram_budget is not None:
      pb.ram_budget = self.ram_budget
    return pb

  def _from_proto(self, pb):
    if pb.WhichOneof("optional_enabled") is not None:
      self.enabled = pb.enabled
    if pb.WhichOneof("optional_cpu_budget") is not None:
      self.cpu_budget = pb.cpu_budget
    if pb.WhichOneof("optional_ram_budget") is not None:
      self.ram_budget = pb.ram_budget

  def _set_mutable(self, mutable):
    """Change the mutability value to `mutable` on this options and children."""
    # pylint: disable=protected-access
    object.__setattr__(self, "_mutable", mutable)
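
# Illustrative usage sketch (values are arbitrary, not defaults): capping the
# resources that autotuning may consume via the budgets described above:
#
#   options = tf.data.Options()
#   options.autotune.cpu_budget = 4              # schedulable CPU cores
#   options.autotune.ram_budget = 4 * 1024 ** 3  # bytes
#   dataset = dataset.with_options(options)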


@tf_export("data.experimental.DistributeOptions")
class DistributeOptions(options_lib.OptionsBase):
  """Represents options for distributed data processing.

  You can set the distribution options of a dataset through the
  `experimental_distribute` property of `tf.data.Options`; the property is
  an instance of `tf.data.experimental.DistributeOptions`.

  ```python
  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
  dataset = dataset.with_options(options)
  ```
  """

  auto_shard_policy = options_lib.create_option(
      name="auto_shard_policy",
      ty=AutoShardPolicy,
      docstring="The type of sharding to use. See "
      "`tf.data.experimental.AutoShardPolicy` for additional information.",
      default_factory=lambda: AutoShardPolicy.AUTO)

  num_devices = options_lib.create_option(
      name="num_devices",
      ty=int,
      docstring=
      "The number of devices attached to this input pipeline. This will be "
      "automatically set by `MultiDeviceIterator`.")

  def _to_proto(self):
    pb = dataset_options_pb2.DistributeOptions()
    pb.auto_shard_policy = AutoShardPolicy._to_proto(self.auto_shard_policy)  # pylint: disable=protected-access
    if self.num_devices is not None:
      pb.num_devices = self.num_devices
    return pb

  def _from_proto(self, pb):
    self.auto_shard_policy = AutoShardPolicy._from_proto(pb.auto_shard_policy)  # pylint: disable=protected-access
    if pb.WhichOneof("optional_num_devices") is not None:
      self.num_devices = pb.num_devices


@tf_export("data.experimental.OptimizationOptions")
class OptimizationOptions(options_lib.OptionsBase):
  """Represents options for dataset optimizations.

  You can set the optimization options of a dataset through the
  `experimental_optimization` property of `tf.data.Options`; the property is
  an instance of `tf.data.experimental.OptimizationOptions`.

  ```python
  options = tf.data.Options()
  options.experimental_optimization.noop_elimination = True
  options.experimental_optimization.apply_default_optimizations = False
  dataset = dataset.with_options(options)
  ```
  """
  apply_default_optimizations = options_lib.create_option(
      name="apply_default_optimizations",
      ty=bool,
      docstring=
      "Whether to apply default graph optimizations. If False, only graph "
      "optimizations that have been explicitly enabled will be applied.")

  filter_fusion = options_lib.create_option(
      name="filter_fusion",
      ty=bool,
      docstring=
      "Whether to fuse filter transformations. If None, defaults to False.")

  map_and_batch_fusion = options_lib.create_option(
      name="map_and_batch_fusion",
      ty=bool,
      docstring=
      "Whether to fuse map and batch transformations. If None, defaults to "
      "True.")

  map_and_filter_fusion = options_lib.create_option(
      name="map_and_filter_fusion",
      ty=bool,
      docstring=
      "Whether to fuse map and filter transformations. If None, defaults to "
      "False.")

  map_fusion = options_lib.create_option(
      name="map_fusion",
      ty=bool,
      docstring="Whether to fuse map transformations. If None, defaults to "
      "False.")

  map_parallelization = options_lib.create_option(
      name="map_parallelization",
      ty=bool,
      docstring=
      "Whether to parallelize stateless map transformations. If None, defaults "
      "to True.")

  noop_elimination = options_lib.create_option(
      name="noop_elimination",
      ty=bool,
      docstring=
      "Whether to eliminate no-op transformations. If None, defaults to True.")

  parallel_batch = options_lib.create_option(
      name="parallel_batch",
      ty=bool,
      docstring="Whether to parallelize copying of batch elements. This "
      "optimization is highly experimental and can cause performance "
      "degradation (e.g. when the parallelization overhead exceeds the "
      "benefits of performing the data copies in parallel). You should only "
      "enable this optimization if a) your input pipeline is bottlenecked on "
      "batching and b) you have validated that this optimization improves "
      "performance. If None, defaults to False.")

  shuffle_and_repeat_fusion = options_lib.create_option(
      name="shuffle_and_repeat_fusion",
      ty=bool,
      docstring="Whether to fuse shuffle and repeat transformations. If None, "
      "defaults to True.")

  def _to_proto(self):
    pb = dataset_options_pb2.OptimizationOptions()
    if self.apply_default_optimizations is not None:
      pb.apply_default_optimizations = self.apply_default_optimizations
    if self.filter_fusion is not None:
      pb.filter_fusion = self.filter_fusion
    if self.map_and_batch_fusion is not None:
      pb.map_and_batch_fusion = self.map_and_batch_fusion
    if self.map_and_filter_fusion is not None:
      pb.map_and_filter_fusion = self.map_and_filter_fusion
    if self.map_fusion is not None:
      pb.map_fusion = self.map_fusion
    if self.map_parallelization is not None:
      pb.map_parallelization = self.map_parallelization
    if self.noop_elimination is not None:
      pb.noop_elimination = self.noop_elimination
    if self.parallel_batch is not None:
      pb.parallel_batch = self.parallel_batch
    if self.shuffle_and_repeat_fusion is not None:
      pb.shuffle_and_repeat_fusion = self.shuffle_and_repeat_fusion
    return pb

  def _from_proto(self, pb):
    if pb.WhichOneof("optional_apply_default_optimizations") is not None:
      self.apply_default_optimizations = pb.apply_default_optimizations
    if pb.WhichOneof("optional_filter_fusion") is not None:
      self.filter_fusion = pb.filter_fusion
    if pb.WhichOneof("optional_map_and_batch_fusion") is not None:
      self.map_and_batch_fusion = pb.map_and_batch_fusion
    if pb.WhichOneof("optional_map_and_filter_fusion") is not None:
      self.map_and_filter_fusion = pb.map_and_filter_fusion
    if pb.WhichOneof("optional_map_fusion") is not None:
      self.map_fusion = pb.map_fusion
    if pb.WhichOneof("optional_map_parallelization") is not None:
      self.map_parallelization = pb.map_parallelization
    if pb.WhichOneof("optional_noop_elimination") is not None:
      self.noop_elimination = pb.noop_elimination
    if pb.WhichOneof("optional_parallel_batch") is not None:
      self.parallel_batch = pb.parallel_batch
    if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
      self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion

  def _set_mutable(self, mutable):
    """Change the mutability value to `mutable` on this options and children."""
    # pylint: disable=protected-access
    object.__setattr__(self, "_mutable", mutable)
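
# Illustrative sketch of the serialization round trip implemented by the
# (private, subject-to-change) `_to_proto`/`_from_proto` helpers above:
#
#   opts = OptimizationOptions()
#   opts.map_parallelization = True
#   pb = opts._to_proto()      # a dataset_options_pb2.OptimizationOptions proto
#   restored = OptimizationOptions()
#   restored._from_proto(pb)   # restored.map_parallelization == True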


@deprecation.deprecated_endpoints("data.experimental.ThreadingOptions")
@tf_export("data.experimental.ThreadingOptions", "data.ThreadingOptions")
class ThreadingOptions(options_lib.OptionsBase):
  """Represents options for dataset threading.

  You can set the threading options of a dataset through the
  `threading` property of `tf.data.Options`; the property is
  an instance of `tf.data.ThreadingOptions`.

  ```python
  options = tf.data.Options()
  options.threading.private_threadpool_size = 10
  dataset = dataset.with_options(options)
  ```
  """

  max_intra_op_parallelism = options_lib.create_option(
      name="max_intra_op_parallelism",
      ty=int,
      docstring=
      "If set, it overrides the maximum degree of intra-op parallelism.")

  private_threadpool_size = options_lib.create_option(
      name="private_threadpool_size",
      ty=int,
      docstring=
      "If set, the dataset will use a private threadpool of the given size. "
      "The value 0 can be used to indicate that the threadpool size should be "
      "determined at runtime based on the number of available CPU cores.")

  def _to_proto(self):
    pb = dataset_options_pb2.ThreadingOptions()
    if self.max_intra_op_parallelism is not None:
      pb.max_intra_op_parallelism = self.max_intra_op_parallelism
    if self.private_threadpool_size is not None:
      pb.private_threadpool_size = self.private_threadpool_size
    return pb

  def _from_proto(self, pb):
    if pb.WhichOneof("optional_max_intra_op_parallelism") is not None:
      self.max_intra_op_parallelism = pb.max_intra_op_parallelism
    if pb.WhichOneof("optional_private_threadpool_size") is not None:
      self.private_threadpool_size = pb.private_threadpool_size
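
# Illustrative usage sketch: requesting a private threadpool whose size is
# picked at runtime from the number of available CPU cores (the `0` sentinel
# described in the `private_threadpool_size` docstring above):
#
#   options = tf.data.Options()
#   options.threading.private_threadpool_size = 0
#   dataset = dataset.with_options(options)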


@tf_export("data.Options")
class Options(options_lib.OptionsBase):
  """Represents options for `tf.data.Dataset`.

  A `tf.data.Options` object can be used, for instance, to control which static
  optimizations to apply to the input pipeline graph or whether to use
  performance modeling to dynamically tune the parallelism of operations such
  as `tf.data.Dataset.map` or `tf.data.Dataset.interleave`.

  The options are set for the entire dataset and are carried over to datasets
  created through tf.data transformations.

  The options can be set by constructing an `Options` object and using the
  `tf.data.Dataset.with_options(options)` transformation, which returns a
  dataset with the options set.

  >>> dataset = tf.data.Dataset.range(42)
  >>> options = tf.data.Options()
  >>> options.deterministic = False
  >>> dataset = dataset.with_options(options)
  >>> print(dataset.options().deterministic)
  False

  Note: A known limitation of the `tf.data.Options` implementation is that the
  options are not preserved across tf.function boundaries. In particular, to
  set options for a dataset that is iterated within a tf.function, the options
  need to be set within the same tf.function.
  """

  autotune = options_lib.create_option(
      name="autotune",
      ty=AutotuneOptions,
      docstring="The autotuning options associated with the dataset. See "
      "`tf.data.experimental.AutotuneOptions` for more details.",
      default_factory=AutotuneOptions)

  deterministic = options_lib.create_option(
      name="deterministic",
      ty=bool,
      docstring=
      "Whether the outputs need to be produced in deterministic order. If None,"
      " defaults to True.")

  experimental_deterministic = options_lib.create_option(
      name="experimental_deterministic",
      ty=bool,
      docstring="DEPRECATED. Use `deterministic` instead.")

  experimental_distribute = options_lib.create_option(
      name="experimental_distribute",
      ty=DistributeOptions,
      docstring=
      "The distribution strategy options associated with the dataset. See "
      "`tf.data.experimental.DistributeOptions` for more details.",
      default_factory=DistributeOptions)

  experimental_external_state_policy = options_lib.create_option(
      name="experimental_external_state_policy",
      ty=ExternalStatePolicy,
      docstring="This option can be used to override the default policy for "
      "how to handle external state when serializing a dataset or "
      "checkpointing its iterator. There are three settings available - "
      "IGNORE: External state is ignored without a warning; WARN: External "
      "state is ignored and a warning is logged; FAIL: External state results "
      "in an error.")

  experimental_optimization = options_lib.create_option(
      name="experimental_optimization",
      ty=OptimizationOptions,
      docstring=
      "The optimization options associated with the dataset. See "
      "`tf.data.experimental.OptimizationOptions` for more details.",
      default_factory=OptimizationOptions)

  experimental_slack = options_lib.create_option(
      name="experimental_slack",
      ty=bool,
      docstring="Whether to introduce 'slack' in the last `prefetch` of the "
      "input pipeline, if it exists. This may reduce CPU contention with "
      "accelerator host-side activity at the start of a step. The slack "
      "frequency is determined by the number of devices attached to this "
      "input pipeline. If None, defaults to False.")

  experimental_threading = options_lib.create_option(
      name="experimental_threading",
      ty=ThreadingOptions,
      docstring="DEPRECATED. Use `threading` instead.")

  threading = options_lib.create_option(
      name="threading",
      ty=ThreadingOptions,
      docstring="The threading options associated with the dataset. See "
      "`tf.data.ThreadingOptions` for more details.",
      default_factory=ThreadingOptions)

  def __getattribute__(self, name):
    if name == "experimental_threading":
      logging.warning("options.experimental_threading is deprecated. "
                      "Use options.threading instead.")
      return getattr(self, "threading")
    if name == "experimental_deterministic":
      # TODO(aaudibert): Uncomment after internal uses have been updated.
      # logging.warning("options.experimental_deterministic is deprecated. "
      #                 "Use options.deterministic instead.")
      return getattr(self, "deterministic")
    return super(Options, self).__getattribute__(name)

  def __setattr__(self, name, value):
    if name == "experimental_threading":
      logging.warning("options.experimental_threading is deprecated. "
                      "Use options.threading instead.")
      super(Options, self).__setattr__("threading", value)
      return
    if name == "experimental_deterministic":
      # TODO(aaudibert): Uncomment after internal uses have been updated.
      # logging.warning("options.experimental_deterministic is deprecated. "
      #                 "Use options.deterministic instead.")
      super(Options, self).__setattr__("deterministic", value)
      return
    super(Options, self).__setattr__(name, value)

  def _to_proto(self):
    pb = dataset_options_pb2.Options()
    if self.deterministic is not None:
      pb.deterministic = self.deterministic
    pb.autotune_options.CopyFrom(self.autotune._to_proto())  # pylint: disable=protected-access
    pb.distribute_options.CopyFrom(self.experimental_distribute._to_proto())  # pylint: disable=protected-access
    if self.experimental_external_state_policy is not None:
      pb.external_state_policy = (
          ExternalStatePolicy._to_proto(  # pylint: disable=protected-access
              self.experimental_external_state_policy))
    pb.optimization_options.CopyFrom(self.experimental_optimization._to_proto())  # pylint: disable=protected-access
    if self.experimental_slack is not None:
      pb.slack = self.experimental_slack
    pb.threading_options.CopyFrom(self.threading._to_proto())  # pylint: disable=protected-access
    return pb

  def _from_proto(self, pb):
    if pb.WhichOneof("optional_deterministic") is not None:
      self.deterministic = pb.deterministic
    self.autotune._from_proto(pb.autotune_options)  # pylint: disable=protected-access
    self.experimental_distribute._from_proto(pb.distribute_options)  # pylint: disable=protected-access
    if pb.WhichOneof("optional_external_state_policy") is not None:
      self.experimental_external_state_policy = (
          ExternalStatePolicy._from_proto(  # pylint: disable=protected-access
              pb.external_state_policy))
    self.experimental_optimization._from_proto(pb.optimization_options)  # pylint: disable=protected-access
    if pb.WhichOneof("optional_slack") is not None:
      self.experimental_slack = pb.slack
    self.threading._from_proto(pb.threading_options)  # pylint: disable=protected-access

  def _set_mutable(self, mutable):
    """Change the mutability value to `mutable` on this options and children."""
    # pylint: disable=protected-access
    object.__setattr__(self, "_mutable", mutable)
    self.autotune._set_mutable(mutable)
    self.experimental_distribute._set_mutable(mutable)
    self.experimental_optimization._set_mutable(mutable)
    self.threading._set_mutable(mutable)

  def merge(self, options):
    """Merges itself with the given `tf.data.Options`.

    If this object and the `options` to merge set an option differently, a
    warning is generated and this object's value is updated with the `options`
    object's value.

    Args:
      options: The `tf.data.Options` to merge with.

    Returns:
      New `tf.data.Options` object which is the result of merging self with
      the input `tf.data.Options`.
    """
    return options_lib.merge_options(self, options)