# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import functools
import sys
import threading
import warnings
import weakref

import numpy as np
import six
from six.moves import queue as Queue  # pylint: disable=redefined-builtin

from tensorflow.core.framework import graph_pb2
from tensorflow.python import tf2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import options as options_lib
from tensorflow.python.data.util import random_seed
from tensorflow.python.data.util import sparse
from tensorflow.python.data.util import structure
from tensorflow.python.data.util import traverse
from tensorflow.python.eager import context
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed as core_random_seed
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.training.tracking import base as tracking_base
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest as tf_nest
from tensorflow.python.util.tf_export import tf_export

# Loaded lazily due to a circular dependency (roughly
# tf.function->wrap_function->dataset->autograph->tf.function).
# TODO(b/133251390): Use a regular import.
wrap_function = lazy_loader.LazyLoader(
    "wrap_function", globals(),
    "tensorflow.python.eager.wrap_function")
# TODO(mdan): Create a public API for this.
autograph_ctx = lazy_loader.LazyLoader(
    "autograph_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")
autograph = lazy_loader.LazyLoader(
    "autograph", globals(),
    "tensorflow.python.autograph.impl.api")

ops.NotDifferentiable("ReduceDataset")

# A constant that can be used to enable auto-tuning.
AUTOTUNE = -1
tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")


@tf_export("data.Dataset", v1=[])
@six.add_metaclass(abc.ABCMeta)
class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
  """Represents a potentially large set of elements.

  The `tf.data.Dataset` API supports writing descriptive and efficient input
  pipelines. `Dataset` usage follows a common pattern:

  1. Create a source dataset from your input data.
  2. Apply dataset transformations to preprocess the data.
  3. Iterate over the dataset and process the elements.

  Iteration happens in a streaming fashion, so the full dataset does not need
  to fit into memory.

  Source Datasets:

  The simplest way to create a dataset is to create it from a python `list`:

  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> for element in dataset:
  ...   print(element)
  tf.Tensor(1, shape=(), dtype=int32)
  tf.Tensor(2, shape=(), dtype=int32)
  tf.Tensor(3, shape=(), dtype=int32)

  To process lines from files, use `tf.data.TextLineDataset`:

  >>> dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])

  To process records written in the `TFRecord` format, use `TFRecordDataset`:

  >>> dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])

  To create a dataset of all files matching a pattern, use
  `tf.data.Dataset.list_files`:

  >>> dataset = tf.data.Dataset.list_files("/path/*.txt")  # doctest: +SKIP

  See `tf.data.FixedLengthRecordDataset` and `tf.data.Dataset.from_generator`
  for more ways to create datasets.

  Transformations:

  Once you have a dataset, you can apply transformations to prepare the data
  for your model:

  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> dataset = dataset.map(lambda x: x*2)
  >>> list(dataset.as_numpy_iterator())
  [2, 4, 6]

  Common Terms:

  **Element**: A single output from calling `next()` on a dataset iterator.
    Elements may be nested structures containing multiple components. For
    example, the element `(1, (3, "apple"))` has one tuple nested in another
    tuple. The components are `1`, `3`, and `"apple"`.

  **Component**: The leaf in the nested structure of an element.

  Supported types:

  Elements can be nested structures of tuples, named tuples, and dictionaries.
  Element components can be of any type representable by `tf.TypeSpec`,
  including `tf.Tensor`, `tf.data.Dataset`, `tf.SparseTensor`,
  `tf.RaggedTensor`, and `tf.TensorArray`.

  >>> a = 1 # Integer element
  >>> b = 2.0 # Float element
  >>> c = (1, 2) # Tuple element with 2 components
  >>> d = {"a": (2, 2), "b": 3} # Dict element with 3 components
  >>> Point = collections.namedtuple("Point", ["x", "y"]) # doctest: +SKIP
  >>> e = Point(1, 2) # Named tuple # doctest: +SKIP
  >>> f = tf.data.Dataset.range(10) # Dataset element
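
  For example, an element with a `tf.SparseTensor` component can be built as
  follows (an illustrative sketch; the printed spec is skipped because its
  repr may differ between versions):

  >>> st = tf.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[3, 4])
  >>> g = tf.data.Dataset.from_tensors(st) # Dataset with a sparse component
  >>> g.element_spec  # doctest: +SKIP
  SparseTensorSpec(TensorShape([3, 4]), tf.int32)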
  """

  def __init__(self, variant_tensor):
    """Creates a DatasetV2 object.

    This is a difference between DatasetV1 and DatasetV2: DatasetV1 does not
    take anything in its constructor, whereas DatasetV2 expects subclasses to
    create a variant_tensor and pass it in to the super() call.

    Args:
      variant_tensor: A DT_VARIANT tensor that represents the dataset.
    """
    self._variant_tensor_attr = variant_tensor
    weak_self = weakref.proxy(self)
    self._variant_tracker = self._track_trackable(
        _VariantTracker(
            self._variant_tensor,
            # _trace_variant_creation only works when executing eagerly, so we
            # don't want to run it immediately. We also want the _VariantTracker
            # to have a weak reference to the Dataset to avoid creating
            # reference cycles and making work for the garbage collector.
            lambda: weak_self._trace_variant_creation()()),  # pylint: disable=unnecessary-lambda,protected-access
        name="_variant_tracker")
    self._graph_attr = ops.get_default_graph()

  @property
  def _variant_tensor(self):
    return self._variant_tensor_attr

  @_variant_tensor.setter
  def _variant_tensor(self, _):
    raise ValueError("The _variant_tensor property is read-only")

  @deprecation.deprecated_args(None, "Use external_state_policy instead",
                               "allow_stateful")
  def _as_serialized_graph(
      self,
      allow_stateful=None,
      strip_device_assignment=None,
      external_state_policy=distribute_options.ExternalStatePolicy.WARN):
    """Produces serialized graph representation of the dataset.

    Args:
      allow_stateful: If true, we allow stateful ops to be present in the graph
        def. In that case, the state in these ops would be thrown away.
      strip_device_assignment: If true, non-local (i.e. job and task) device
        assignment is stripped from ops in the serialized graph.
      external_state_policy: The ExternalStatePolicy enum that determines how
        we handle input pipelines that depend on external state. By default,
        it's set to WARN.

    Returns:
      A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
      serialized graph.
    """
    if external_state_policy:
      policy = external_state_policy.value
      return gen_dataset_ops.dataset_to_graph_v2(
          self._variant_tensor,
          external_state_policy=policy,
          strip_device_assignment=strip_device_assignment)
    if strip_device_assignment:
      return gen_dataset_ops.dataset_to_graph(
          self._variant_tensor,
          allow_stateful=allow_stateful,
          strip_device_assignment=strip_device_assignment)
    return gen_dataset_ops.dataset_to_graph(
        self._variant_tensor, allow_stateful=allow_stateful)

  def _trace_variant_creation(self):
    """Traces a function which outputs a variant `tf.Tensor` for this dataset.

    Note that creating this function involves evaluating an op, and is
    currently only supported when executing eagerly.

    Returns:
      A zero-argument `ConcreteFunction` which outputs a variant `tf.Tensor`.
    """
    variant = self._variant_tensor
    if not isinstance(variant, ops.EagerTensor):
      raise NotImplementedError(
          "Can only export Datasets which were created executing eagerly. "
          "Please file a feature request if this is important to you.")
    with context.eager_mode(), ops.device("CPU"):
      # pylint: disable=protected-access
      graph_def = graph_pb2.GraphDef().FromString(
          self._as_serialized_graph(external_state_policy=distribute_options
                                    .ExternalStatePolicy.FAIL).numpy())
    output_node_name = None
    for node in graph_def.node:
      if node.op == "_Retval":
        if output_node_name is not None:
          raise AssertionError(
              "Found multiple return values from the dataset's graph, expected "
              "only one.")
        output_node_name, = node.input
    if output_node_name is None:
      raise AssertionError("Could not find the dataset's output node.")
    # Add functions used in this Dataset to the function's graph, since they
    # need to follow it around (and for example be added to a SavedModel which
    # references the dataset).
    variant_function = wrap_function.function_from_graph_def(
        graph_def, inputs=[], outputs=output_node_name + ":0")
    for used_function in self._functions():
      used_function.function.add_to_graph(variant_function.graph)
    return variant_function

  @abc.abstractmethod
  def _inputs(self):
    """Returns a list of the input datasets of the dataset."""

    raise NotImplementedError("Dataset._inputs")

  @property
  def _graph(self):
    return self._graph_attr

  @_graph.setter
  def _graph(self, _):
    raise ValueError("The _graph property is read-only")

  def _has_captured_ref(self):
    """Whether this dataset uses a function that captures ref variables.

    Returns:
      A boolean, which if true indicates that the dataset or one of its inputs
      uses a function that captures ref variables.
    """
    if context.executing_eagerly():
      # RefVariables are not supported in eager mode
      return False

    def is_tensor_or_parent_ref(tensor):
      if tensor.dtype._is_ref_dtype:  # pylint: disable=protected-access
        return True
      # If the captured tensor is an eager tensor, we cannot trace its inputs.
      if isinstance(tensor, ops._EagerTensorBase):  # pylint: disable=protected-access
        return False
      return any(is_tensor_or_parent_ref(x) for x in tensor.op.inputs)

    for fn in self._functions():
      if any(is_tensor_or_parent_ref(t) for t in fn.function.captured_inputs):
        return True

    return any(
        [input_dataset._has_captured_ref() for input_dataset in self._inputs()])  # pylint: disable=protected-access

  # TODO(jsimsa): Change this to be the transitive closure of functions used
  # by this dataset and its inputs.
  def _functions(self):
    """Returns a list of functions associated with this dataset.

    Returns:
      A list of `StructuredFunctionWrapper` objects.
    """
    return []

  def options(self):
    """Returns the options for this dataset and its inputs.
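
    Options can be attached to a dataset with `tf.data.Dataset.with_options`.
    A minimal illustrative example (assuming no other options are applied
    elsewhere in the pipeline):

    >>> dataset = tf.data.Dataset.range(3)
    >>> options = tf.data.Options()
    >>> options.experimental_deterministic = False
    >>> dataset = dataset.with_options(options)
    >>> dataset.options().experimental_deterministic
    False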

    Returns:
      A `tf.data.Options` object representing the dataset options.
    """
    options = Options()
    for input_dataset in self._inputs():
      input_options = input_dataset.options()
      if input_options is not None:
        options = options.merge(input_options)
    return options

  def _apply_options(self):
    """Apply options, such as optimization configuration, to the dataset."""

    dataset = self
    options = self.options()

    # (1) Apply threading options
    if options.experimental_threading is not None:
      t_options = options.experimental_threading
      if t_options.max_intra_op_parallelism is not None:
        dataset = _MaxIntraOpParallelismDataset(
            dataset, t_options.max_intra_op_parallelism)
      if t_options.private_threadpool_size is not None:
        dataset = _PrivateThreadPoolDataset(dataset,
                                            t_options.private_threadpool_size)

    # (2) Apply graph rewrite options
    # pylint: disable=protected-access
    graph_rewrites = options._graph_rewrites()
    graph_rewrite_configs = options._graph_rewrite_configs()
    # pylint: enable=protected-access
    if graph_rewrites:
      if self._has_captured_ref():
        warnings.warn(
            "tf.data graph rewrites are not compatible with tf.Variable. "
            "The following rewrites will be disabled: %s. To enable "
            "rewrites, use resource variables instead by calling "
            "`tf.enable_resource_variables()` at the start of the program." %
            ", ".join(graph_rewrites))
      else:
        dataset = _OptimizeDataset(dataset, graph_rewrites,
                                   graph_rewrite_configs)

    # (3) Apply autotune options
    autotune, algorithm, cpu_budget = options._autotune_settings()  # pylint: disable=protected-access

    if autotune:
      dataset = _ModelDataset(dataset, algorithm, cpu_budget)

    # (4) Apply stats aggregator options
    if options.experimental_stats and options.experimental_stats.aggregator:  # pylint: disable=line-too-long
      dataset = _SetStatsAggregatorDataset(  # pylint: disable=protected-access
          dataset, options.experimental_stats.aggregator,
          options.experimental_stats.prefix,
          options.experimental_stats.counter_prefix)
    return dataset

  def __iter__(self):
    """Creates an `Iterator` for enumerating the elements of this dataset.

    The returned iterator implements the Python iterator protocol and therefore
    can only be used in eager mode.

    Returns:
      An `Iterator` over the elements of this dataset.

    Raises:
      RuntimeError: If not inside of tf.function and not executing eagerly.
    """
    if (context.executing_eagerly()
        or ops.get_default_graph()._building_function):  # pylint: disable=protected-access
      return iterator_ops.OwnedIterator(self)
    else:
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")

  @abc.abstractproperty
  def element_spec(self):
    """The type specification of an element of this dataset.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> dataset.element_spec
    TensorSpec(shape=(), dtype=tf.int32, name=None)
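
    For a dataset whose elements are nested structures, `element_spec` mirrors
    the element structure. A small illustrative example:

    >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3.0, 4.0]))
    >>> dataset.element_spec
    (TensorSpec(shape=(), dtype=tf.int32, name=None), TensorSpec(shape=(), \
dtype=tf.float32, name=None))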

    Returns:
      A nested structure of `tf.TypeSpec` objects matching the structure of an
      element of this dataset and specifying the type of individual components.
    """
    raise NotImplementedError("Dataset.element_spec")

  def __repr__(self):
    output_shapes = nest.map_structure(str, get_legacy_output_shapes(self))
    output_shapes = str(output_shapes).replace("'", "")
    output_types = nest.map_structure(repr, get_legacy_output_types(self))
    output_types = str(output_types).replace("'", "")
    return ("<%s shapes: %s, types: %s>" % (type(self).__name__, output_shapes,
                                            output_types))

  def as_numpy_iterator(self):
    """Returns an iterator which converts all elements of the dataset to numpy.

    Use `as_numpy_iterator` to inspect the content of your dataset. To see
    element shapes and types, print dataset elements directly instead of using
    `as_numpy_iterator`.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> for element in dataset:
    ...   print(element)
    tf.Tensor(1, shape=(), dtype=int32)
    tf.Tensor(2, shape=(), dtype=int32)
    tf.Tensor(3, shape=(), dtype=int32)

    This method requires that you are running in eager mode and the dataset's
    element_spec contains only `TensorSpec` components.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> for element in dataset.as_numpy_iterator():
    ...   print(element)
    1
    2
    3

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> print(list(dataset.as_numpy_iterator()))
    [1, 2, 3]

    `as_numpy_iterator()` will preserve the nested structure of dataset
    elements.

    >>> dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]),
    ...                                               'b': [5, 6]})
    >>> list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
    ...                                       {'a': (2, 4), 'b': 6}]
    True

    Returns:
      An iterable over the elements of the dataset, with their tensors converted
      to numpy arrays.

    Raises:
      TypeError: if an element contains a non-`Tensor` value.
      RuntimeError: if eager execution is not enabled.
    """
    if not context.executing_eagerly():
      raise RuntimeError("as_numpy_iterator() is not supported while tracing "
                         "functions")
    for component_spec in nest.flatten(self.element_spec):
      if not isinstance(component_spec, tensor_spec.TensorSpec):
        raise TypeError(
            "Dataset.as_numpy_iterator() does not support datasets containing "
            + str(component_spec.value_type))

    return _NumpyIterator(self)

  @property
  def _flat_shapes(self):
    """Returns a list of `tf.TensorShape`s for the element tensor representation.

    Returns:
      A list of `tf.TensorShape`s for the element tensor representation.
    """
    return structure.get_flat_tensor_shapes(self.element_spec)

  @property
  def _flat_types(self):
    """Returns a list of `tf.DType`s for the element tensor representation.

    Returns:
      A list of `tf.DType`s for the element tensor representation.
    """
    return structure.get_flat_tensor_types(self.element_spec)

  @property
  def _flat_structure(self):
    """Helper for setting `output_shapes` and `output_types` attrs of an op.

    Most dataset op constructors expect `output_shapes` and `output_types`
    arguments that represent the flattened structure of an element. This helper
    function generates these attrs as a keyword argument dictionary, allowing
    `Dataset._variant_tensor` implementations to pass `**self._flat_structure`
    to the op constructor.
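
    For example, a transformation such as `PrefetchDataset` builds its variant
    tensor roughly as follows (a simplified sketch; the real call passes
    additional op-specific arguments):

    ```python
    variant_tensor = gen_dataset_ops.prefetch_dataset(
        input_dataset._variant_tensor,
        buffer_size=buffer_size,
        **self._flat_structure)
    ```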

    Returns:
      A dictionary of keyword arguments that can be passed to a dataset op
      constructor.
    """
    return {
        "output_shapes": self._flat_shapes,
        "output_types": self._flat_types,
    }

  @property
  def _type_spec(self):
    return DatasetSpec(self.element_spec)

  @staticmethod
  def from_tensors(tensors):
    """Creates a `Dataset` with a single element, comprising the given tensors.

    >>> dataset = tf.data.Dataset.from_tensors([1, 2, 3])
    >>> list(dataset.as_numpy_iterator())
    [array([1, 2, 3], dtype=int32)]
    >>> dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A'))
    >>> list(dataset.as_numpy_iterator())
    [(array([1, 2, 3], dtype=int32), b'A')]

    Note that if `tensors` contains a NumPy array, and eager execution is not
    enabled, the values will be embedded in the graph as one or more
    `tf.constant` operations. For large datasets (> 1 GB), this can waste
    memory and run into byte limits of graph serialization. If `tensors`
    contains one or more large NumPy arrays, consider the alternative described
    in [this
    guide](https://tensorflow.org/guide/data#consuming_numpy_arrays).

    Args:
      tensors: A dataset element.

    Returns:
      Dataset: A `Dataset`.
    """
    return TensorDataset(tensors)

  @staticmethod
  def from_tensor_slices(tensors):
    """Creates a `Dataset` whose elements are slices of the given tensors.

    The given tensors are sliced along their first dimension. This operation
    preserves the structure of the input tensors, removing the first dimension
    of each tensor and using it as the dataset dimension. All input tensors
    must have the same size in their first dimensions.

    >>> # Slicing a 1D tensor produces scalar tensor elements.
    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> list(dataset.as_numpy_iterator())
    [1, 2, 3]

    >>> # Slicing a 2D tensor produces 1D tensor elements.
    >>> dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
    >>> list(dataset.as_numpy_iterator())
    [array([1, 2], dtype=int32), array([3, 4], dtype=int32)]

    >>> # Slicing a tuple of 1D tensors produces tuple elements containing
    >>> # scalar tensors.
    >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
    >>> list(dataset.as_numpy_iterator())
    [(1, 3, 5), (2, 4, 6)]

    >>> # Dictionary structure is also preserved.
    >>> dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
    >>> list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
    ...                                       {'a': 2, 'b': 4}]
    True

    >>> # Two tensors can be combined into one Dataset object.
    >>> features = tf.constant([[1, 3], [2, 1], [3, 3]]) # ==> 3x2 tensor
    >>> labels = tf.constant(['A', 'B', 'A']) # ==> 3-element vector
    >>> dataset = Dataset.from_tensor_slices((features, labels))
    >>> # Both the features and the labels tensors can be converted
    >>> # to a Dataset object separately and combined after.
    >>> features_dataset = Dataset.from_tensor_slices(features)
    >>> labels_dataset = Dataset.from_tensor_slices(labels)
    >>> dataset = Dataset.zip((features_dataset, labels_dataset))
    >>> # A batched feature and label set can be converted to a Dataset
    >>> # in similar fashion.
    >>> batched_features = tf.constant([[[1, 3], [2, 3]],
    ...                                 [[2, 1], [1, 2]],
    ...                                 [[3, 3], [3, 2]]], shape=(3, 2, 2))
    >>> batched_labels = tf.constant([['A', 'A'],
    ...                               ['B', 'B'],
    ...                               ['A', 'B']], shape=(3, 2, 1))
    >>> dataset = Dataset.from_tensor_slices((batched_features, batched_labels))
    >>> for element in dataset.as_numpy_iterator():
    ...   print(element)
    (array([[1, 3],
           [2, 3]], dtype=int32), array([[b'A'],
           [b'A']], dtype=object))
    (array([[2, 1],
           [1, 2]], dtype=int32), array([[b'B'],
           [b'B']], dtype=object))
    (array([[3, 3],
           [3, 2]], dtype=int32), array([[b'A'],
           [b'B']], dtype=object))

    Note that if `tensors` contains a NumPy array, and eager execution is not
    enabled, the values will be embedded in the graph as one or more
    `tf.constant` operations. For large datasets (> 1 GB), this can waste
    memory and run into byte limits of graph serialization. If `tensors`
    contains one or more large NumPy arrays, consider the alternative described
    in [this guide](
    https://tensorflow.org/guide/data#consuming_numpy_arrays).

    Args:
      tensors: A dataset element, with each component having the same size in
        the first dimension.

    Returns:
      Dataset: A `Dataset`.
    """
    return TensorSliceDataset(tensors)

  class _GeneratorState(object):
    """Stores outstanding iterators created from a Python generator.

    This class keeps track of potentially multiple iterators that may have
    been created from a generator, e.g. in the case that the dataset is
    repeated, or nested within a parallel computation.
    """

    def __init__(self, generator):
      self._generator = generator
      self._lock = threading.Lock()
      self._next_id = 0  # GUARDED_BY(self._lock)
      self._args = {}
      self._iterators = {}

    def get_next_id(self, *args):
      with self._lock:
        ret = self._next_id
        self._next_id += 1
      self._args[ret] = args
      # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
      # casting in `py_func()` will create an array of `np.int32` on Windows,
      # leading to a runtime error.
      return np.array(ret, dtype=np.int64)

    def get_iterator(self, iterator_id):
      try:
        return self._iterators[iterator_id]
      except KeyError:
        iterator = iter(self._generator(*self._args.pop(iterator_id)))
        self._iterators[iterator_id] = iterator
        return iterator

    def iterator_completed(self, iterator_id):
      del self._iterators[iterator_id]

  @staticmethod
  def from_generator(generator, output_types, output_shapes=None, args=None):
    """Creates a `Dataset` whose elements are generated by `generator`.

    The `generator` argument must be a callable object that returns
    an object that supports the `iter()` protocol (e.g. a generator function).
    The elements generated by `generator` must be compatible with the given
    `output_types` and (optional) `output_shapes` arguments.

    >>> import itertools
    >>>
    >>> def gen():
    ...   for i in itertools.count(1):
    ...     yield (i, [1] * i)
    >>>
    >>> dataset = tf.data.Dataset.from_generator(
    ...     gen,
    ...     (tf.int64, tf.int64),
    ...     (tf.TensorShape([]), tf.TensorShape([None])))
    >>>
    >>> list(dataset.take(3).as_numpy_iterator())
    [(1, array([1])), (2, array([1, 1])), (3, array([1, 1, 1]))]

    NOTE: The current implementation of `Dataset.from_generator()` uses
    `tf.numpy_function` and inherits the same constraints. In particular, it
    requires the `Dataset`- and `Iterator`-related operations to be placed
    on a device in the same process as the Python program that called
    `Dataset.from_generator()`. The body of `generator` will not be
    serialized in a `GraphDef`, and you should not use this method if you
    need to serialize your model and restore it in a different environment.

    NOTE: If `generator` depends on mutable global variables or other external
    state, be aware that the runtime may invoke `generator` multiple times
    (in order to support repeating the `Dataset`) and at any time
    between the call to `Dataset.from_generator()` and the production of the
    first element from the generator. Mutating global variables or external
    state can cause undefined behavior, and we recommend that you explicitly
    cache any external state in `generator` before calling
    `Dataset.from_generator()`.
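
    If `args` is supplied, its values are passed to `generator` as NumPy-array
    arguments each time a new iterator is created. A small illustrative sketch
    (output skipped, since argument conversion details may vary):

    >>> def gen(limit):
    ...   for i in range(int(limit)):
    ...     yield i
    >>>
    >>> dataset = tf.data.Dataset.from_generator(
    ...     gen, tf.int64, tf.TensorShape([]), args=(3,))
    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
    [0, 1, 2]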

    Args:
      generator: A callable object that returns an object that supports the
        `iter()` protocol. If `args` is not specified, `generator` must take no
        arguments; otherwise it must take as many arguments as there are values
        in `args`.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element yielded by `generator`.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element yielded by `generator`.
      args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
        and passed to `generator` as NumPy-array arguments.

    Returns:
      Dataset: A `Dataset`.
    """
    if not callable(generator):
      raise TypeError("`generator` must be callable.")
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    if args is None:
      args = ()
    else:
      args = tuple(ops.convert_n_to_tensor(args, name="args"))

    flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)]
    flattened_shapes = nest.flatten(output_shapes)

    generator_state = DatasetV2._GeneratorState(generator)

    def get_iterator_id_fn(unused_dummy):
      """Creates a unique `iterator_id` for each pass over the dataset.

      The returned `iterator_id` disambiguates between multiple concurrently
      existing iterators.

      Args:
        unused_dummy: Ignored value.

      Returns:
        A `tf.int64` tensor whose value uniquely identifies an iterator in
        `generator_state`.
      """
      return script_ops.numpy_function(generator_state.get_next_id, args,
                                       dtypes.int64)

    def generator_next_fn(iterator_id_t):
      """Generates the next element from iterator with ID `iterator_id_t`.

      We map this function across an infinite repetition of the
      `iterator_id_t`, and raise `StopIteration` to terminate the iteration.

      Args:
        iterator_id_t: A `tf.int64` tensor whose value uniquely identifies the
          iterator in `generator_state` from which to generate an element.

      Returns:
        The next element to generate from the iterator.
      """

      def generator_py_func(iterator_id):
        """A `py_func` that will be called to invoke the iterator."""
        # `next()` raises `StopIteration` when there are no more
        # elements remaining to be generated.
        values = next(generator_state.get_iterator(iterator_id))

        # Use the same _convert function from the py_func() implementation to
        # convert the returned values to arrays early, so that we can inspect
        # their values.
        try:
          flattened_values = nest.flatten_up_to(output_types, values)
        except (TypeError, ValueError):
          six.reraise(TypeError, TypeError(
              "`generator` yielded an element that did not match the expected "
              "structure. The expected structure was %s, but the yielded "
              "element was %s." % (output_types, values)), sys.exc_info()[2])
        ret_arrays = []
        for ret, dtype in zip(flattened_values, flattened_types):
          try:
            ret_arrays.append(script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
                ret, dtype=dtype.as_numpy_dtype))
          except (TypeError, ValueError):
            six.reraise(TypeError, TypeError(
                "`generator` yielded an element that could not be converted to "
                "the expected type. The expected type was %s, but the yielded "
                "element was %s." % (dtype.name, ret)), sys.exc_info()[2])

        # Additional type and shape checking to ensure that the components
        # of the generated element match the `output_types` and `output_shapes`
        # arguments.
        for (ret_array, expected_dtype, expected_shape) in zip(
            ret_arrays, flattened_types, flattened_shapes):
          if ret_array.dtype != expected_dtype.as_numpy_dtype:
            raise TypeError(
                "`generator` yielded an element of type %s where an element "
                "of type %s was expected." % (ret_array.dtype,
                                              expected_dtype.as_numpy_dtype))
          if not expected_shape.is_compatible_with(ret_array.shape):
            raise ValueError(
                "`generator` yielded an element of shape %s where an element "
                "of shape %s was expected." % (ret_array.shape, expected_shape))

        return ret_arrays

      flat_values = script_ops.numpy_function(generator_py_func,
                                              [iterator_id_t], flattened_types)

      # The `py_func()` op drops the inferred shapes, so we add them back in
      # here.
      if output_shapes is not None:
        for ret_t, shape in zip(flat_values, flattened_shapes):
          ret_t.set_shape(shape)

      return nest.pack_sequence_as(output_types, flat_values)

    def finalize_fn(iterator_id_t):
      """Releases host-side state for the iterator with ID `iterator_id_t`."""

      def finalize_py_func(iterator_id):
        generator_state.iterator_completed(iterator_id)
        # We return a dummy value so that the `finalize_fn` has a valid
        # signature.
        # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
        # casting in `py_func()` will create an array of `np.int32` on Windows,
        # leading to a runtime error.
        return np.array(0, dtype=np.int64)

      return script_ops.numpy_function(finalize_py_func, [iterator_id_t],
                                       dtypes.int64)

    # This function associates each traversal of `generator` with a unique
    # iterator ID.
    def flat_map_fn(dummy_arg):
      # The `get_iterator_id_fn` gets a unique ID for the current instance of
      # the generator.
      # The `generator_next_fn` gets the next element from the iterator with
      # the given ID, and raises StopIteration when that iterator contains no
      # more elements.
      return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn,
                               finalize_fn)

    # A single-element dataset that, each time it is evaluated, contains a
    # freshly-generated and unique (for the returned dataset) int64
    # ID that will be used to identify the appropriate Python state, which
    # is encapsulated in `generator_state`, and captured in
    # `get_iterator_id_map_fn`.
    dummy = 0
    id_dataset = Dataset.from_tensors(dummy)

    # A dataset that contains all of the elements generated by a
    # single iterator created from `generator`, identified by the
    # iterator ID contained in `id_dataset`. Lifting the iteration
    # into a flat_map here enables multiple repetitions and/or nested
    # versions of the returned dataset to be created, because it forces
    # the generation of a new ID for each version.
    return id_dataset.flat_map(flat_map_fn)

  @staticmethod
  def range(*args, **kwargs):
    """Creates a `Dataset` of a step-separated range of values.

    >>> list(Dataset.range(5).as_numpy_iterator())
    [0, 1, 2, 3, 4]
    >>> list(Dataset.range(2, 5).as_numpy_iterator())
    [2, 3, 4]
    >>> list(Dataset.range(1, 5, 2).as_numpy_iterator())
    [1, 3]
    >>> list(Dataset.range(1, 5, -2).as_numpy_iterator())
    []
    >>> list(Dataset.range(5, 1).as_numpy_iterator())
    []
    >>> list(Dataset.range(5, 1, -2).as_numpy_iterator())
    [5, 3]
    >>> list(Dataset.range(2, 5, output_type=tf.int32).as_numpy_iterator())
    [2, 3, 4]
    >>> list(Dataset.range(1, 5, 2, output_type=tf.float32).as_numpy_iterator())
    [1.0, 3.0]

    Args:
      *args: follows the same semantics as python's xrange.
        len(args) == 1 -> start = 0, stop = args[0], step = 1
        len(args) == 2 -> start = args[0], stop = args[1], step = 1
        len(args) == 3 -> start = args[0], stop = args[1], step = args[2]
      **kwargs:
        - output_type: Its expected dtype. (Optional, default: `tf.int64`).

    Returns:
      Dataset: A `RangeDataset`.

    Raises:
      ValueError: if len(args) == 0.
    """
    return RangeDataset(*args, **kwargs)

  @staticmethod
  def zip(datasets):
    """Creates a `Dataset` by zipping together the given datasets.

    This method has similar semantics to the built-in `zip()` function
    in Python, with the main difference being that the `datasets`
    argument can be an arbitrary nested structure of `Dataset` objects.

    >>> # The nested structure of the `datasets` argument determines the
    >>> # structure of elements in the resulting dataset.
    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
    >>> b = tf.data.Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
    >>> ds = tf.data.Dataset.zip((a, b))
    >>> list(ds.as_numpy_iterator())
    [(1, 4), (2, 5), (3, 6)]
    >>> ds = tf.data.Dataset.zip((b, a))
    >>> list(ds.as_numpy_iterator())
    [(4, 1), (5, 2), (6, 3)]
    >>>
    >>> # The `datasets` argument may contain an arbitrary number of datasets.
    >>> c = tf.data.Dataset.range(7, 13).batch(2)  # ==> [ [7, 8],
    ...                                            #      [9, 10],
    ...                                            #      [11, 12] ]
    >>> ds = tf.data.Dataset.zip((a, b, c))
    >>> for element in ds.as_numpy_iterator():
    ...   print(element)
    (1, 4, array([7, 8]))
    (2, 5, array([ 9, 10]))
    (3, 6, array([11, 12]))
    >>>
    >>> # The number of elements in the resulting dataset is the same as
    >>> # the size of the smallest dataset in `datasets`.
    >>> d = tf.data.Dataset.range(13, 15)  # ==> [ 13, 14 ]
    >>> ds = tf.data.Dataset.zip((a, d))
    >>> list(ds.as_numpy_iterator())
    [(1, 13), (2, 14)]
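
    The `datasets` argument can also be a dictionary or other nested structure;
    the element structure then mirrors the structure of `datasets`. An
    illustrative example (output skipped, since dict ordering in the printed
    result may vary):

    >>> ds = tf.data.Dataset.zip({"x": a, "y": b})
    >>> list(ds.as_numpy_iterator())  # doctest: +SKIP
    [{'x': 1, 'y': 4}, {'x': 2, 'y': 5}, {'x': 3, 'y': 6}]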

    Args:
      datasets: A nested structure of datasets.

    Returns:
      Dataset: A `Dataset`.
    """
    return ZipDataset(datasets)

  def concatenate(self, dataset):
    """Creates a `Dataset` by concatenating the given dataset with this dataset.

    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
    >>> b = tf.data.Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
    >>> ds = a.concatenate(b)
    >>> list(ds.as_numpy_iterator())
    [1, 2, 3, 4, 5, 6, 7]
    >>> # The input dataset and dataset to be concatenated should have the same
    >>> # nested structures and output types.
    >>> c = tf.data.Dataset.zip((a, b))
    >>> a.concatenate(c)
    Traceback (most recent call last):
    TypeError: Two datasets to concatenate have different types
    <dtype: 'int64'> and (tf.int64, tf.int64)
    >>> d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
    >>> a.concatenate(d)
    Traceback (most recent call last):
    TypeError: Two datasets to concatenate have different types
    <dtype: 'int64'> and <dtype: 'string'>

    Args:
      dataset: `Dataset` to be concatenated.

    Returns:
      Dataset: A `Dataset`.
    """
    return ConcatenateDataset(self, dataset)

  def prefetch(self, buffer_size):
    """Creates a `Dataset` that prefetches elements from this dataset.

    Most dataset input pipelines should end with a call to `prefetch`. This
    allows later elements to be prepared while the current element is being
    processed. This often improves latency and throughput, at the cost of
    using additional memory to store prefetched elements.

    Note: Like other `Dataset` methods, prefetch operates on the
    elements of the input dataset. It has no concept of examples vs. batches.
    `examples.prefetch(2)` will prefetch two elements (2 examples),
    while `examples.batch(20).prefetch(2)` will prefetch 2 elements
    (2 batches, of 20 examples each).

    >>> dataset = tf.data.Dataset.range(3)
    >>> dataset = dataset.prefetch(2)
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 2]

    Args:
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the maximum
        number of elements that will be buffered when prefetching.

    Returns:
      Dataset: A `Dataset`.
    """
    return PrefetchDataset(self, buffer_size)

  @staticmethod
  def list_files(file_pattern, shuffle=None, seed=None):
    """A dataset of all files matching one or more glob patterns.

    The `file_pattern` argument should be a small number of glob patterns.
    If your filenames have already been globbed, use
    `Dataset.from_tensor_slices(filenames)` instead, as re-globbing every
    filename with `list_files` may result in poor performance with remote
    storage systems.

    NOTE: The default behavior of this method is to return filenames in
    a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False`
    to get results in a deterministic order.

    Example:
      If we had the following files on our filesystem:
        - /path/to/dir/a.txt
        - /path/to/dir/b.py
        - /path/to/dir/c.py
      If we pass "/path/to/dir/*.py" as the `file_pattern`, the dataset
      would produce:
        - /path/to/dir/b.py
        - /path/to/dir/c.py
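
    A doctest-style version of the example above (skipped here, since it
    depends on files existing on the local filesystem):

    >>> dataset = tf.data.Dataset.list_files("/path/to/dir/*.py",
    ...                                      shuffle=False)  # doctest: +SKIP
    >>> [f.numpy() for f in dataset]  # doctest: +SKIP
    [b'/path/to/dir/b.py', b'/path/to/dir/c.py']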

    Args:
      file_pattern: A string, a list of strings, or a `tf.Tensor` of string type
        (scalar or vector), representing the filename glob (i.e. shell wildcard)
        pattern(s) that will be matched.
      shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
        Defaults to `True`.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
        seed that will be used to create the distribution. See
        `tf.random.set_seed` for behavior.

    Returns:
      Dataset: A `Dataset` of strings corresponding to file names.
    """
    with ops.name_scope("list_files"):
      if shuffle is None:
        shuffle = True
      file_pattern = ops.convert_to_tensor(
          file_pattern, dtype=dtypes.string, name="file_pattern")
      matching_files = gen_io_ops.matching_files(file_pattern)

      # Raise an exception if `file_pattern` does not match any files.
      condition = math_ops.greater(array_ops.shape(matching_files)[0], 0,
                                   name="match_not_empty")

      message = math_ops.add(
          "No files matched pattern: ",
          string_ops.reduce_join(file_pattern, separator=", "), name="message")

      assert_not_empty = control_flow_ops.Assert(
          condition, [message], summarize=1, name="assert_not_empty")
      with ops.control_dependencies([assert_not_empty]):
        matching_files = array_ops.identity(matching_files)

      dataset = Dataset.from_tensor_slices(matching_files)
      if shuffle:
        # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
        # list of files might be empty.
        buffer_size = math_ops.maximum(
            array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
        dataset = dataset.shuffle(buffer_size, seed=seed)
      return dataset

  def repeat(self, count=None):
    """Repeats this dataset so each original value is seen `count` times.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> dataset = dataset.repeat(3)
    >>> list(dataset.as_numpy_iterator())
    [1, 2, 3, 1, 2, 3, 1, 2, 3]

    NOTE: If this dataset is a function of global state (e.g. a random number
    generator), then different repetitions may produce different elements.

    Args:
      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        number of times the dataset should be repeated. The default behavior
        (if `count` is `None` or `-1`) is for the dataset to be repeated
        indefinitely.

    Returns:
      Dataset: A `Dataset`.
    """
    return RepeatDataset(self, count)

  def enumerate(self, start=0):
    """Enumerates the elements of this dataset.

    It is similar to python's `enumerate`.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> dataset = dataset.enumerate(start=5)
    >>> for element in dataset.as_numpy_iterator():
    ...   print(element)
    (5, 1)
    (6, 2)
    (7, 3)

    >>> # The nested structure of the input dataset determines the structure of
    >>> # elements in the resulting dataset.
    >>> dataset = tf.data.Dataset.from_tensor_slices([(7, 8), (9, 10)])
    >>> dataset = dataset.enumerate()
    >>> for element in dataset.as_numpy_iterator():
    ...   print(element)
    (0, array([7, 8], dtype=int32))
    (1, array([ 9, 10], dtype=int32))

    Args:
      start: A `tf.int64` scalar `tf.Tensor`, representing the start value for
        enumeration.

    Returns:
      Dataset: A `Dataset`.
    """

    max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
    return Dataset.zip((Dataset.range(start, max_value), self))

  def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
    """Randomly shuffles the elements of this dataset.

    This dataset fills a buffer with `buffer_size` elements, then randomly
    samples elements from this buffer, replacing the selected elements with new
    elements. For perfect shuffling, a buffer size greater than or equal to the
    full size of the dataset is required.

    For instance, if your dataset contains 10,000 elements but `buffer_size` is
    set to 1,000, then `shuffle` will initially select a random element from
    only the first 1,000 elements in the buffer. Once an element is selected,
    its space in the buffer is replaced by the next (i.e. 1,001-st) element,
    maintaining the 1,000 element buffer.

    `reshuffle_each_iteration` controls whether the shuffle order should be
    different for each epoch. In TF 1.X, the idiomatic way to create epochs
    was through the `repeat` transformation:

    >>> dataset = tf.data.Dataset.range(3)
    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
    >>> dataset = dataset.repeat(2)  # doctest: +SKIP
    [1, 0, 2, 1, 2, 0]

    >>> dataset = tf.data.Dataset.range(3)
    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
    >>> dataset = dataset.repeat(2)  # doctest: +SKIP
    [1, 0, 2, 1, 0, 2]

    In TF 2.0, `tf.data.Dataset` objects are Python iterables which makes it
    possible to also create epochs through Python iteration:

    >>> dataset = tf.data.Dataset.range(3)
    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
    [1, 0, 2]
    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
    [1, 2, 0]

    >>> dataset = tf.data.Dataset.range(3)
    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
    [1, 0, 2]
    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
    [1, 0, 2]

    Args:
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        elements from this dataset from which the new dataset will sample.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
        seed that will be used to create the distribution. See
        `tf.random.set_seed` for behavior.
      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
        that the dataset should be pseudorandomly reshuffled each time it is
        iterated over. (Defaults to `True`.)

    Returns:
      Dataset: A `Dataset`.
    """
    return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)

  def cache(self, filename=""):
    """Caches the elements in this dataset.

    The first time the dataset is iterated over, its elements will be cached
    either in the specified file or in memory. Subsequent iterations will
    use the cached data.

    Note: For the cache to be finalized, the input dataset must be iterated
    through in its entirety. Otherwise, subsequent iterations will not use
    cached data.

    >>> dataset = tf.data.Dataset.range(5)
    >>> dataset = dataset.map(lambda x: x**2)
    >>> dataset = dataset.cache()
    >>> # The first time reading through the data will generate the data using
    >>> # `range` and `map`.
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 4, 9, 16]
    >>> # Subsequent iterations read from the cache.
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 4, 9, 16]

    When caching to a file, the cached data will persist across runs. Even the
    first iteration through the data will read from the cache file. Changing
    the input pipeline before the call to `.cache()` will have no effect until
    the cache file is removed or the filename is changed.

    >>> dataset = tf.data.Dataset.range(5)
    >>> dataset = dataset.cache("/path/to/file")  # doctest: +SKIP
    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
    [0, 1, 2, 3, 4]
    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.cache("/path/to/file")  # Same file!  # doctest: +SKIP
    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
    [0, 1, 2, 3, 4]

    Note: `cache` will produce exactly the same elements during each iteration
    through the dataset. If you wish to randomize the iteration order, make
    sure to call `shuffle` *after* calling `cache`.

    Args:
      filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
        directory on the filesystem to use for caching elements in this
        Dataset. If a filename is not provided, the dataset will be cached in
        memory.

    Returns:
      Dataset: A `Dataset`.
    """
    return CacheDataset(self, filename)

  def take(self, count):
    """Creates a `Dataset` with at most `count` elements from this dataset.

    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.take(3)
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 2]

    Args:
      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
        elements of this dataset that should be taken to form the new dataset.
        If `count` is -1, or if `count` is greater than the size of this
        dataset, the new dataset will contain all elements of this dataset.

    Returns:
      Dataset: A `Dataset`.
    """
    return TakeDataset(self, count)

  def skip(self, count):
    """Creates a `Dataset` that skips `count` elements from this dataset.

    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.skip(7)
    >>> list(dataset.as_numpy_iterator())
    [7, 8, 9]

    Args:
      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
        elements of this dataset that should be skipped to form the new
        dataset. If `count` is greater than the size of this dataset, the new
        dataset will contain no elements. If `count` is -1, skips the entire
        dataset.

    Returns:
      Dataset: A `Dataset`.
    """
    return SkipDataset(self, count)

  def shard(self, num_shards, index):
    """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.

    `shard` is deterministic. The Dataset produced by `A.shard(n, i)` will
    contain all elements of A whose index mod n = i.

    >>> A = tf.data.Dataset.range(10)
    >>> B = A.shard(num_shards=3, index=0)
    >>> list(B.as_numpy_iterator())
    [0, 3, 6, 9]
    >>> C = A.shard(num_shards=3, index=1)
    >>> list(C.as_numpy_iterator())
    [1, 4, 7]
    >>> D = A.shard(num_shards=3, index=2)
    >>> list(D.as_numpy_iterator())
    [2, 5, 8]

    This dataset operator is very useful when running distributed training, as
    it allows each worker to read a unique subset.

    When reading a single input file, you can shard elements as follows:

    ```python
    d = tf.data.TFRecordDataset(input_file)
    d = d.shard(num_workers, worker_index)
    d = d.repeat(num_epochs)
    d = d.shuffle(shuffle_buffer_size)
    d = d.map(parser_fn, num_parallel_calls=num_map_threads)
    ```

    Important caveats:

    - Be sure to shard before you use any randomizing operator (such as
      shuffle).
    - Generally it is best if the shard operator is used early in the dataset
      pipeline.
      For example, when reading from a set of TFRecord files, shard
      before converting the dataset to input samples. This avoids reading every
      file on every worker. The following is an example of an efficient
      sharding strategy within a complete pipeline:

    ```python
    d = Dataset.list_files(pattern)
    d = d.shard(num_workers, worker_index)
    d = d.repeat(num_epochs)
    d = d.shuffle(shuffle_buffer_size)
    d = d.interleave(tf.data.TFRecordDataset,
                     cycle_length=num_readers, block_length=1)
    d = d.map(parser_fn, num_parallel_calls=num_map_threads)
    ```

    Args:
      num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
        shards operating in parallel.
      index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.

    Returns:
      Dataset: A `Dataset`.

    Raises:
      InvalidArgumentError: if `num_shards` or `index` are illegal values.
        Note: error checking is done on a best-effort basis, and errors aren't
        guaranteed to be caught upon dataset creation. (e.g. providing a
        placeholder tensor bypasses the early checking, and will instead result
        in an error during a session.run call.)
    """
    return ShardDataset(self, num_shards, index)

  def batch(self, batch_size, drop_remainder=False):
    """Combines consecutive elements of this dataset into batches.

    >>> dataset = tf.data.Dataset.range(8)
    >>> dataset = dataset.batch(3)
    >>> list(dataset.as_numpy_iterator())
    [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]

    >>> dataset = tf.data.Dataset.range(8)
    >>> dataset = dataset.batch(3, drop_remainder=True)
    >>> list(dataset.as_numpy_iterator())
    [array([0, 1, 2]), array([3, 4, 5])]

    The components of the resulting element will have an additional outer
    dimension, which will be `batch_size` (or `N % batch_size` for the last
    element if `batch_size` does not divide the number of input elements `N`
    evenly and `drop_remainder` is `False`). If your program depends on the
    batches having the same outer dimension, you should set the
    `drop_remainder` argument to `True` to prevent the smaller batch from
    being produced.

    Args:
      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        consecutive elements of this dataset to combine in a single batch.
      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
        whether the last batch should be dropped in the case it has fewer than
        `batch_size` elements; the default behavior is not to drop the smaller
        batch.

    Returns:
      Dataset: A `Dataset`.
    """
    return BatchDataset(self, batch_size, drop_remainder)

  def padded_batch(self,
                   batch_size,
                   padded_shapes=None,
                   padding_values=None,
                   drop_remainder=False):
    """Combines consecutive elements of this dataset into padded batches.

    This transformation combines multiple consecutive elements of the input
    dataset into a single element.

    Like `tf.data.Dataset.batch`, the components of the resulting element will
    have an additional outer dimension, which will be `batch_size` (or
    `N % batch_size` for the last element if `batch_size` does not divide the
    number of input elements `N` evenly and `drop_remainder` is `False`).
    If your program depends on the batches having the same outer dimension,
    you should set the `drop_remainder` argument to `True` to prevent the
    smaller batch from being produced.

    Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
    different shapes, and this transformation will pad each component to the
    respective shape in `padded_shapes`. The `padded_shapes` argument
    determines the resulting shape for each dimension of each component in an
    output element:

    * If the dimension is a constant, the component will be padded out to that
      length in that dimension.
    * If the dimension is unknown, the component will be padded out to the
      maximum length of all elements in that dimension.

    >>> A = (tf.data.Dataset
    ...      .range(1, 5, output_type=tf.int32)
    ...      .map(lambda x: tf.fill([x], x)))
    >>> # Pad to the smallest per-batch size that fits all elements.
    >>> B = A.padded_batch(2)
    >>> for element in B.as_numpy_iterator():
    ...   print(element)
    [[1 0]
     [2 2]]
    [[3 3 3 0]
     [4 4 4 4]]
    >>> # Pad to a fixed size.
    >>> C = A.padded_batch(2, padded_shapes=5)
    >>> for element in C.as_numpy_iterator():
    ...   print(element)
    [[1 0 0 0 0]
     [2 2 0 0 0]]
    [[3 3 3 0 0]
     [4 4 4 4 0]]
    >>> # Pad with a custom value.
    >>> D = A.padded_batch(2, padded_shapes=5, padding_values=-1)
    >>> for element in D.as_numpy_iterator():
    ...   print(element)
    [[ 1 -1 -1 -1 -1]
     [ 2  2 -1 -1 -1]]
    [[ 3  3  3 -1 -1]
     [ 4  4  4  4 -1]]
    >>> # Components of nested elements can be padded independently.
    >>> elements = [([1, 2, 3], [10]),
    ...             ([4, 5], [11, 12])]
    >>> dataset = tf.data.Dataset.from_generator(
    ...     lambda: iter(elements), (tf.int32, tf.int32))
    >>> # Pad the first component of the tuple to length 4, and the second
    >>> # component to the smallest size that fits.
    >>> dataset = dataset.padded_batch(2,
    ...     padded_shapes=([4], [None]),
    ...     padding_values=(-1, 100))
    >>> list(dataset.as_numpy_iterator())
    [(array([[ 1,  2,  3, -1], [ 4,  5, -1, -1]], dtype=int32),
      array([[ 10, 100], [ 11,  12]], dtype=int32))]

    See also `tf.data.experimental.dense_to_sparse_batch`, which combines
    elements that may have different shapes into a `tf.SparseTensor`.

    Args:
      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        consecutive elements of this dataset to combine in a single batch.
      padded_shapes: (Optional.) A nested structure of `tf.TensorShape` or
        `tf.int64` vector tensor-like objects representing the shape to which
        the respective component of each input element should be padded prior
        to batching. Any unknown dimensions will be padded to the maximum size
        of that dimension in each batch. If unset, all dimensions of all
        components are padded to the maximum size in the batch. `padded_shapes`
        must be set if any component has an unknown rank.
      padding_values: (Optional.) A nested structure of scalar-shaped
        `tf.Tensor`, representing the padding values to use for the respective
        components. None represents that the nested structure should be padded
        with default values. Defaults are `0` for numeric types and the empty
        string for string types.
      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
A `tf.bool` scalar `tf.Tensor`, representing 1465 whether the last batch should be dropped in the case it has fewer than 1466 `batch_size` elements; the default behavior is not to drop the smaller 1467 batch. 1468 1469 Returns: 1470 Dataset: A `Dataset`. 1471 1472 Raises: 1473 ValueError: If a component has an unknown rank, and the `padded_shapes` 1474 argument is not set. 1475 """ 1476 if padded_shapes is None: 1477 padded_shapes = get_legacy_output_shapes(self) 1478 # A `tf.TensorShape` is only falsey if its *rank* is unknown: 1479 # bool(tf.TensorShape(None)) is False 1480 if not all(nest.flatten(padded_shapes)): 1481 raise ValueError("You must set the `padded_shapes` argument to " 1482 "`Dataset.padded_batch` if any component of its input " 1483 "has an unknown rank") 1484 return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values, 1485 drop_remainder) 1486 1487 def map(self, map_func, num_parallel_calls=None): 1488 """Maps `map_func` across the elements of this dataset. 1489 1490 This transformation applies `map_func` to each element of this dataset, and 1491 returns a new dataset containing the transformed elements, in the same 1492 order as they appeared in the input. `map_func` can be used to change both 1493 the values and the structure of a dataset's elements. For example, adding 1 1494 to each element, or projecting a subset of element components. 1495 1496 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] 1497 >>> dataset = dataset.map(lambda x: x + 1) 1498 >>> list(dataset.as_numpy_iterator()) 1499 [2, 3, 4, 5, 6] 1500 1501 The input signature of `map_func` is determined by the structure of each 1502 element in this dataset. 1503 1504 >>> dataset = Dataset.range(5) 1505 >>> # `map_func` takes a single argument of type `tf.Tensor` with the same 1506 >>> # shape and dtype. 1507 >>> result = dataset.map(lambda x: x + 1) 1508 1509 >>> # Each element is a tuple containing two `tf.Tensor` objects. 1510 >>> elements = [(1, "foo"), (2, "bar"), (3, "baz")] 1511 >>> dataset = tf.data.Dataset.from_generator( 1512 ... lambda: elements, (tf.int32, tf.string)) 1513 >>> # `map_func` takes two arguments of type `tf.Tensor`. This function 1514 >>> # projects out just the first component. 1515 >>> result = dataset.map(lambda x_int, y_str: x_int) 1516 >>> list(result.as_numpy_iterator()) 1517 [1, 2, 3] 1518 1519 >>> # Each element is a dictionary mapping strings to `tf.Tensor` objects. 1520 >>> elements = ([{"a": 1, "b": "foo"}, 1521 ... {"a": 2, "b": "bar"}, 1522 ... {"a": 3, "b": "baz"}]) 1523 >>> dataset = tf.data.Dataset.from_generator( 1524 ... lambda: elements, {"a": tf.int32, "b": tf.string}) 1525 >>> # `map_func` takes a single argument of type `dict` with the same keys 1526 >>> # as the elements. 1527 >>> result = dataset.map(lambda d: str(d["a"]) + d["b"]) 1528 1529 The value or values returned by `map_func` determine the structure of each 1530 element in the returned dataset. 1531 1532 >>> dataset = tf.data.Dataset.range(3) 1533 >>> # `map_func` returns two `tf.Tensor` objects. 1534 >>> def g(x): 1535 ... return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"]) 1536 >>> result = dataset.map(g) 1537 >>> result.element_spec 1538 (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(3,), \ 1539dtype=tf.string, name=None)) 1540 >>> # Python primitives, lists, and NumPy arrays are implicitly converted to 1541 >>> # `tf.Tensor`. 1542 >>> def h(x): 1543 ...
return 37.0, ["Foo", "Bar"], np.array([1.0, 2.0], dtype=np.float64) 1544 >>> result = dataset.map(h) 1545 >>> result.element_spec 1546 (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(2,), \ 1547dtype=tf.string, name=None), TensorSpec(shape=(2,), dtype=tf.float64, \ 1548name=None)) 1549 >>> # `map_func` can return nested structures. 1550 >>> def i(x): 1551 ... return (37.0, [42, 16]), "foo" 1552 >>> result = dataset.map(i) 1553 >>> result.element_spec 1554 ((TensorSpec(shape=(), dtype=tf.float32, name=None), 1555 TensorSpec(shape=(2,), dtype=tf.int32, name=None)), 1556 TensorSpec(shape=(), dtype=tf.string, name=None)) 1557 1558 `map_func` can accept as arguments and return any type of dataset element. 1559 1560 Note that irrespective of the context in which `map_func` is defined (eager 1561 vs. graph), tf.data traces the function and executes it as a graph. To use 1562 Python code inside of the function you have two options: 1563 1564 1) Rely on AutoGraph to convert Python code into an equivalent graph 1565 computation. The downside of this approach is that AutoGraph can convert 1566 some but not all Python code. 1567 1568 2) Use `tf.py_function`, which allows you to write arbitrary Python code but 1569 will generally result in worse performance than 1). For example: 1570 1571 >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world']) 1572 >>> # transform a string tensor to upper case string using a Python function 1573 >>> def upper_case_fn(t: tf.Tensor): 1574 ... return t.numpy().decode('utf-8').upper() 1575 >>> d = d.map(lambda x: tf.py_function(func=upper_case_fn, 1576 ... inp=[x], Tout=tf.string)) 1577 >>> list(d.as_numpy_iterator()) 1578 [b'HELLO', b'WORLD'] 1579 1580 Args: 1581 map_func: A function mapping a dataset element to another dataset element. 1582 num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, 1583 representing the number elements to process asynchronously in parallel. 1584 If not specified, elements will be processed sequentially. If the value 1585 `tf.data.experimental.AUTOTUNE` is used, then the number of parallel 1586 calls is set dynamically based on available CPU. 1587 1588 Returns: 1589 Dataset: A `Dataset`. 1590 """ 1591 if num_parallel_calls is None: 1592 return MapDataset(self, map_func, preserve_cardinality=True) 1593 else: 1594 return ParallelMapDataset( 1595 self, map_func, num_parallel_calls, preserve_cardinality=True) 1596 1597 def flat_map(self, map_func): 1598 """Maps `map_func` across this dataset and flattens the result. 1599 1600 Use `flat_map` if you want to make sure that the order of your dataset 1601 stays the same. For example, to flatten a dataset of batches into a 1602 dataset of their elements: 1603 1604 >>> dataset = Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 1605 >>> dataset = dataset.flat_map(lambda x: Dataset.from_tensor_slices(x)) 1606 >>> list(dataset.as_numpy_iterator()) 1607 [1, 2, 3, 4, 5, 6, 7, 8, 9] 1608 1609 `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since 1610 `flat_map` produces the same output as 1611 `tf.data.Dataset.interleave(cycle_length=1)` 1612 1613 Args: 1614 map_func: A function mapping a dataset element to a dataset. 1615 1616 Returns: 1617 Dataset: A `Dataset`. 1618 """ 1619 return FlatMapDataset(self, map_func) 1620 1621 def interleave(self, 1622 map_func, 1623 cycle_length=AUTOTUNE, 1624 block_length=1, 1625 num_parallel_calls=None, 1626 deterministic=None): 1627 """Maps `map_func` across this dataset, and interleaves the results. 
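    When `cycle_length` is 1 (with the default `block_length` of 1), the
    transformation degenerates to `tf.data.Dataset.flat_map` (an illustrative
    check):

    >>> ds = tf.data.Dataset.range(3).interleave(
    ...     lambda x: tf.data.Dataset.from_tensors(x).repeat(2),
    ...     cycle_length=1)
    >>> list(ds.as_numpy_iterator())
    [0, 0, 1, 1, 2, 2]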
1628 1629 For example, you can use `Dataset.interleave()` to process many input files 1630 concurrently: 1631 1632 >>> # Preprocess 4 files concurrently, and interleave blocks of 16 records 1633 >>> # from each file. 1634 >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt", 1635 ... "/var/data/file3.txt", "/var/data/file4.txt"] 1636 >>> dataset = tf.data.Dataset.from_tensor_slices(filenames) 1637 >>> def parse_fn(filename): 1638 ... return tf.data.Dataset.range(10) 1639 >>> dataset = dataset.interleave(lambda x: 1640 ... tf.data.TextLineDataset(x).map(parse_fn, num_parallel_calls=1), 1641 ... cycle_length=4, block_length=16) 1642 1643 The `cycle_length` and `block_length` arguments control the order in which 1644 elements are produced. `cycle_length` controls the number of input elements 1645 that are processed concurrently. If you set `cycle_length` to 1, this 1646 transformation will handle one input element at a time, and will produce 1647 identical results to `tf.data.Dataset.flat_map`. In general, 1648 this transformation will apply `map_func` to `cycle_length` input elements, 1649 open iterators on the returned `Dataset` objects, and cycle through them 1650 producing `block_length` consecutive elements from each iterator, and 1651 consuming the next input element each time it reaches the end of an 1652 iterator. 1653 1654 For example: 1655 1656 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] 1657 >>> # NOTE: New lines indicate "block" boundaries. 1658 >>> dataset = dataset.interleave( 1659 ... lambda x: Dataset.from_tensors(x).repeat(6), 1660 ... cycle_length=2, block_length=4) 1661 >>> list(dataset.as_numpy_iterator()) 1662 [1, 1, 1, 1, 1663 2, 2, 2, 2, 1664 1, 1, 1665 2, 2, 1666 3, 3, 3, 3, 1667 4, 4, 4, 4, 1668 3, 3, 1669 4, 4, 1670 5, 5, 5, 5, 1671 5, 5] 1672 1673 NOTE: The order of elements yielded by this transformation is 1674 deterministic, as long as `map_func` is a pure function and 1675 `deterministic=True`. If `map_func` contains any stateful operations, the 1676 order in which that state is accessed is undefined. 1677 1678 Performance can often be improved by setting `num_parallel_calls` so that 1679 `interleave` will use multiple threads to fetch elements. If determinism 1680 isn't required, it can also improve performance to set 1681 `deterministic=False`. 1682 1683 >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt", 1684 ... "/var/data/file3.txt", "/var/data/file4.txt"] 1685 >>> dataset = tf.data.Dataset.from_tensor_slices(filenames) 1686 >>> dataset = dataset.interleave(lambda x: tf.data.TFRecordDataset(x), 1687 ... cycle_length=4, num_parallel_calls=tf.data.experimental.AUTOTUNE, 1688 ... deterministic=False) 1689 1690 Args: 1691 map_func: A function mapping a dataset element to a dataset. 1692 cycle_length: (Optional.) The number of input elements that will be 1693 processed concurrently. If not specified, the value will be derived from 1694 the number of available CPU cores. If the `num_parallel_calls` argument 1695 is set to `tf.data.experimental.AUTOTUNE`, the `cycle_length` argument 1696 also identifies the maximum degree of parallelism. 1697 block_length: (Optional.) The number of consecutive elements to produce 1698 from each input element before cycling to another input element. 1699 num_parallel_calls: (Optional.) If specified, the implementation creates a 1700 threadpool, which is used to fetch inputs from cycle elements 1701 asynchronously and in parallel. 
The default behavior is to fetch inputs 1702 from cycle elements synchronously with no parallelism. If the value 1703 `tf.data.experimental.AUTOTUNE` is used, then the number of parallel 1704 calls is set dynamically based on available CPU. 1705 deterministic: (Optional.) A boolean controlling whether determinism 1706 should be traded for performance by allowing elements to be produced out 1707 of order. If `deterministic` is `None`, the 1708 `tf.data.Options.experimental_deterministic` dataset option (`True` by 1709 default) is used to decide whether to produce elements 1710 deterministically. 1711 1712 Returns: 1713 Dataset: A `Dataset`. 1714 """ 1715 if num_parallel_calls is None: 1716 return InterleaveDataset(self, map_func, cycle_length, block_length) 1717 else: 1718 return ParallelInterleaveDataset(self, map_func, cycle_length, 1719 block_length, num_parallel_calls, 1720 deterministic) 1721 1722 def filter(self, predicate): 1723 """Filters this dataset according to `predicate`. 1724 1725 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 1726 >>> dataset = dataset.filter(lambda x: x < 3) 1727 >>> list(dataset.as_numpy_iterator()) 1728 [1, 2] 1729 >>> # `tf.math.equal(x, y)` is required for equality comparison 1730 >>> def filter_fn(x): 1731 ... return tf.math.equal(x, 1) 1732 >>> dataset = dataset.filter(filter_fn) 1733 >>> list(dataset.as_numpy_iterator()) 1734 [1] 1735 1736 Args: 1737 predicate: A function mapping a dataset element to a boolean. 1738 1739 Returns: 1740 Dataset: The `Dataset` containing the elements of this dataset for which 1741 `predicate` is `True`. 1742 """ 1743 return FilterDataset(self, predicate) 1744 1745 def apply(self, transformation_func): 1746 """Applies a transformation function to this dataset. 1747 1748 `apply` enables chaining of custom `Dataset` transformations, which are 1749 represented as functions that take one `Dataset` argument and return a 1750 transformed `Dataset`. 1751 1752 >>> dataset = tf.data.Dataset.range(100) 1753 >>> def dataset_fn(ds): 1754 ... return ds.filter(lambda x: x < 5) 1755 >>> dataset = dataset.apply(dataset_fn) 1756 >>> list(dataset.as_numpy_iterator()) 1757 [0, 1, 2, 3, 4] 1758 1759 Args: 1760 transformation_func: A function that takes one `Dataset` argument and 1761 returns a `Dataset`. 1762 1763 Returns: 1764 Dataset: The `Dataset` returned by applying `transformation_func` to this 1765 dataset. 1766 """ 1767 dataset = transformation_func(self) 1768 if not isinstance(dataset, DatasetV2): 1769 raise TypeError( 1770 "`transformation_func` must return a Dataset. Got {}.".format( 1771 dataset)) 1772 dataset._input_datasets = [self] # pylint: disable=protected-access 1773 return dataset 1774 1775 def window(self, size, shift=None, stride=1, drop_remainder=False): 1776 """Combines (nests of) input elements into a dataset of (nests of) windows. 1777 1778 A "window" is a finite dataset of flat elements of size `size` (or possibly 1779 fewer if there are not enough input elements to fill the window and 1780 `drop_remainder` evaluates to false). 1781 1782 The `stride` argument determines the stride of the input elements, and the 1783 `shift` argument determines the shift of the window. 1784 1785 >>> dataset = tf.data.Dataset.range(7).window(2) 1786 >>> for window in dataset: 1787 ... print(list(window.as_numpy_iterator())) 1788 [0, 1] 1789 [2, 3] 1790 [4, 5] 1791 [6] 1792 >>> dataset = tf.data.Dataset.range(7).window(3, 2, 1, True) 1793 >>> for window in dataset: 1794 ... 
print(list(window.as_numpy_iterator())) 1795 [0, 1, 2] 1796 [2, 3, 4] 1797 [4, 5, 6] 1798 >>> dataset = tf.data.Dataset.range(7).window(3, 1, 2, True) 1799 >>> for window in dataset: 1800 ... print(list(window.as_numpy_iterator())) 1801 [0, 2, 4] 1802 [1, 3, 5] 1803 [2, 4, 6] 1804 1805 Note that when the `window` transformation is applied to a dataset of 1806 nested elements, it produces a dataset of nested windows. 1807 1808 >>> nested = ([1, 2, 3, 4], [5, 6, 7, 8]) 1809 >>> dataset = tf.data.Dataset.from_tensor_slices(nested).window(2) 1810 >>> for window in dataset: 1811 ... def to_numpy(ds): 1812 ... return list(ds.as_numpy_iterator()) 1813 ... print(tuple(to_numpy(component) for component in window)) 1814 ([1, 2], [5, 6]) 1815 ([3, 4], [7, 8]) 1816 1817 >>> dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3, 4]}) 1818 >>> dataset = dataset.window(2) 1819 >>> for window in dataset: 1820 ... def to_numpy(ds): 1821 ... return list(ds.as_numpy_iterator()) 1822 ... print({'a': to_numpy(window['a'])}) 1823 {'a': [1, 2]} 1824 {'a': [3, 4]} 1825 1826 Args: 1827 size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements 1828 of the input dataset to combine into a window. 1829 shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 1830 forward shift of the sliding window in each iteration. Defaults to 1831 `size`. 1832 stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 1833 stride of the input elements in the sliding window. 1834 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 1835 whether a window should be dropped in case its size is smaller than 1836 `window_size`. 1837 1838 Returns: 1839 Dataset: A `Dataset` of (nests of) windows -- a finite datasets of flat 1840 elements created from the (nests of) input elements. 1841 1842 """ 1843 if shift is None: 1844 shift = size 1845 return WindowDataset(self, size, shift, stride, drop_remainder) 1846 1847 def reduce(self, initial_state, reduce_func): 1848 """Reduces the input dataset to a single element. 1849 1850 The transformation calls `reduce_func` successively on every element of 1851 the input dataset until the dataset is exhausted, aggregating information in 1852 its internal state. The `initial_state` argument is used for the initial 1853 state and the final state is returned as the result. 1854 1855 >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1).numpy() 1856 5 1857 >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y).numpy() 1858 10 1859 1860 Args: 1861 initial_state: An element representing the initial state of the 1862 transformation. 1863 reduce_func: A function that maps `(old_state, input_element)` to 1864 `new_state`. It must take two arguments and return a new element 1865 The structure of `new_state` must match the structure of 1866 `initial_state`. 1867 1868 Returns: 1869 A dataset element corresponding to the final state of the transformation. 1870 1871 """ 1872 1873 with ops.name_scope("initial_state"): 1874 initial_state = structure.normalize_element(initial_state) 1875 state_structure = structure.type_spec_from_value(initial_state) 1876 1877 # Iteratively rerun the reduce function until reaching a fixed point on 1878 # `state_structure`. 
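    # The structure inferred from `initial_state` may be more specific than the
    # states `reduce_func` actually produces (e.g. a dimension that grows from
    # step to step). Each pass of the loop below traces `reduce_func` against
    # the current `state_structure`, relaxes any mismatched shapes to the most
    # specific compatible shapes, and repeats until the structure is stable.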
1879 need_to_rerun = True 1880 while need_to_rerun: 1881 1882 wrapped_func = StructuredFunctionWrapper( 1883 reduce_func, 1884 "reduce()", 1885 input_structure=(state_structure, self.element_spec), 1886 add_to_graph=False) 1887 1888 # Extract and validate class information from the returned values. 1889 output_classes = wrapped_func.output_classes 1890 state_classes = nest.map_structure( 1891 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 1892 state_structure) 1893 for new_state_class, state_class in zip( 1894 nest.flatten(output_classes), nest.flatten(state_classes)): 1895 if not issubclass(new_state_class, state_class): 1896 raise TypeError( 1897 "The element classes for the new state must match the initial " 1898 "state. Expected %s; got %s." % 1899 (state_classes, wrapped_func.output_classes)) 1900 1901 # Extract and validate type information from the returned values. 1902 output_types = wrapped_func.output_types 1903 state_types = nest.map_structure( 1904 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 1905 state_structure) 1906 for new_state_type, state_type in zip( 1907 nest.flatten(output_types), nest.flatten(state_types)): 1908 if new_state_type != state_type: 1909 raise TypeError( 1910 "The element types for the new state must match the initial " 1911 "state. Expected %s; got %s." % 1912 (state_types, wrapped_func.output_types)) 1913 1914 # Extract shape information from the returned values. 1915 output_shapes = wrapped_func.output_shapes 1916 state_shapes = nest.map_structure( 1917 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 1918 state_structure) 1919 flat_state_shapes = nest.flatten(state_shapes) 1920 flat_new_state_shapes = nest.flatten(output_shapes) 1921 weakened_state_shapes = [ 1922 original.most_specific_compatible_shape(new) 1923 for original, new in zip(flat_state_shapes, flat_new_state_shapes) 1924 ] 1925 1926 need_to_rerun = False 1927 for original_shape, weakened_shape in zip(flat_state_shapes, 1928 weakened_state_shapes): 1929 if original_shape.ndims is not None and ( 1930 weakened_shape.ndims is None or 1931 original_shape.as_list() != weakened_shape.as_list()): 1932 need_to_rerun = True 1933 break 1934 1935 if need_to_rerun: 1936 # TODO(b/110122868): Support a "most specific compatible structure" 1937 # method for combining structures, to avoid using legacy structures 1938 # here. 1939 state_structure = structure.convert_legacy_structure( 1940 state_types, 1941 nest.pack_sequence_as(state_shapes, weakened_state_shapes), 1942 state_classes) 1943 1944 reduce_func = wrapped_func.function 1945 reduce_func.add_to_graph(ops.get_default_graph()) 1946 1947 dataset = self._apply_options() 1948 1949 # pylint: disable=protected-access 1950 return structure.from_compatible_tensor_list( 1951 state_structure, 1952 gen_dataset_ops.reduce_dataset( 1953 dataset._variant_tensor, 1954 structure.to_tensor_list(state_structure, initial_state), 1955 reduce_func.captured_inputs, 1956 f=reduce_func, 1957 output_shapes=structure.get_flat_tensor_shapes(state_structure), 1958 output_types=structure.get_flat_tensor_types(state_structure))) 1959 1960 def unbatch(self): 1961 """Splits elements of a dataset into multiple elements. 
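    In the simplest case, `unbatch` undoes an earlier `Dataset.batch`
    (an illustrative check):

    >>> ds = tf.data.Dataset.range(6).batch(4).unbatch()
    >>> list(ds.as_numpy_iterator())
    [0, 1, 2, 3, 4, 5]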
1962 1963 For example, if elements of the dataset are shaped `[B, a0, a1, ...]`, 1964 where `B` may vary for each input element, then for each element in the 1965 dataset, the unbatched dataset will contain `B` consecutive elements 1966 of shape `[a0, a1, ...]`. 1967 1968 >>> elements = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ] 1969 >>> dataset = tf.data.Dataset.from_generator(lambda: elements, tf.int64) 1970 >>> dataset = dataset.unbatch() 1971 >>> list(dataset.as_numpy_iterator()) 1972 [1, 2, 3, 1, 2, 1, 2, 3, 4] 1973 1974 Returns: 1975 A `Dataset`. 1976 """ 1977 normalized_dataset = normalize_to_dense(self) 1978 return _UnbatchDataset(normalized_dataset) 1979 1980 def with_options(self, options): 1981 """Returns a new `tf.data.Dataset` with the given options set. 1982 1983 The options are "global" in the sense they apply to the entire dataset. 1984 If options are set multiple times, they are merged as long as different 1985 options do not use different non-default values. 1986 1987 >>> ds = tf.data.Dataset.range(5) 1988 >>> ds = ds.interleave(lambda x: tf.data.Dataset.range(5), 1989 ... cycle_length=3, 1990 ... num_parallel_calls=3) 1991 >>> options = tf.data.Options() 1992 >>> # This will make the interleave order non-deterministic. 1993 >>> options.experimental_deterministic = False 1994 >>> ds = ds.with_options(options) 1995 1996 Args: 1997 options: A `tf.data.Options` that identifies the options the use. 1998 1999 Returns: 2000 Dataset: A `Dataset` with the given options. 2001 2002 Raises: 2003 ValueError: when an option is set more than once to a non-default value 2004 """ 2005 return _OptionsDataset(self, options) 2006 2007 2008@tf_export(v1=["data.Dataset"]) 2009class DatasetV1(DatasetV2): 2010 """Represents a potentially large set of elements. 2011 2012 A `Dataset` can be used to represent an input pipeline as a 2013 collection of elements and a "logical plan" of transformations that act on 2014 those elements. 2015 """ 2016 2017 def __init__(self): 2018 try: 2019 variant_tensor = self._as_variant_tensor() 2020 except AttributeError as e: 2021 if "_as_variant_tensor" in str(e): 2022 raise AttributeError("Please use _variant_tensor instead of " 2023 "_as_variant_tensor() to obtain the variant " 2024 "associated with a dataset") 2025 raise AttributeError("{}: A likely cause of this error is that the super " 2026 "call for this dataset is not the last line of the " 2027 "__init__ method. The base class causes the " 2028 "_as_variant_tensor call in its constructor and " 2029 "if that uses attributes defined in the __init__ " 2030 "method, those attrs need to be defined before the " 2031 "super call.".format(e)) 2032 super(DatasetV1, self).__init__(variant_tensor) 2033 2034 @abc.abstractmethod 2035 def _as_variant_tensor(self): 2036 """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset. 2037 2038 Returns: 2039 A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset. 2040 """ 2041 raise NotImplementedError("Dataset._as_variant_tensor") 2042 2043 @deprecation.deprecated( 2044 None, "Use `for ... in dataset:` to iterate over a dataset. If using " 2045 "`tf.estimator`, return the `Dataset` object directly from your input " 2046 "function. As a last resort, you can use " 2047 "`tf.compat.v1.data.make_one_shot_iterator(dataset)`.") 2048 def make_one_shot_iterator(self): 2049 """Creates an `Iterator` for enumerating the elements of this dataset. 2050 2051 Note: The returned iterator will be initialized automatically. 
2052 A "one-shot" iterator does not currently support re-initialization. 2053 2054 Returns: 2055 An `Iterator` over the elements of this dataset. 2056 """ 2057 return self._make_one_shot_iterator() 2058 2059 def _make_one_shot_iterator(self): # pylint: disable=missing-docstring 2060 if context.executing_eagerly(): 2061 return iterator_ops.OwnedIterator(self) 2062 2063 _ensure_same_dataset_graph(self) 2064 # Now that we create datasets at python object creation time, the capture 2065 # by value _make_dataset() function would try to capture these variant 2066 # tensor dataset inputs, which are marked as stateful ops and would throw 2067 # an error if we try and capture them. We therefore traverse the graph 2068 # to find all these ops and whitelist them so that the capturing 2069 # logic instead of throwing an error recreates these ops which is what was 2070 # happening before. 2071 all_ds_ops = traverse.obtain_all_variant_tensor_ops(self) 2072 graph_level_seed, op_level_seed = core_random_seed.get_seed(None) 2073 2074 # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is 2075 # a 0-argument function. 2076 @function.Defun(capture_by_value=True, whitelisted_stateful_ops=all_ds_ops) 2077 def _make_dataset(): 2078 """Factory function for a dataset.""" 2079 # NOTE(mrry): `Defun` does not capture the graph-level seed from the 2080 # enclosing graph, so if a graph-level seed is present we set the local 2081 # graph seed based on a combination of the graph- and op-level seeds. 2082 if graph_level_seed is not None: 2083 assert op_level_seed is not None 2084 core_random_seed.set_random_seed( 2085 (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1)) 2086 2087 dataset = self._apply_options() 2088 return dataset._variant_tensor # pylint: disable=protected-access 2089 2090 try: 2091 _make_dataset.add_to_graph(ops.get_default_graph()) 2092 except ValueError as err: 2093 if "Cannot capture a stateful node" in str(err): 2094 raise ValueError( 2095 "Failed to create a one-shot iterator for a dataset. " 2096 "`Dataset.make_one_shot_iterator()` does not support datasets that " 2097 "capture stateful objects, such as a `Variable` or `LookupTable`. " 2098 "In these cases, use `Dataset.make_initializable_iterator()`. " 2099 "(Original error: %s)" % err) 2100 else: 2101 six.reraise(ValueError, err) 2102 2103 # pylint: disable=protected-access 2104 return iterator_ops.Iterator( 2105 gen_dataset_ops.one_shot_iterator( 2106 dataset_factory=_make_dataset, **self._flat_structure), None, 2107 get_legacy_output_types(self), get_legacy_output_shapes(self), 2108 get_legacy_output_classes(self)) 2109 2110 @deprecation.deprecated( 2111 None, "Use `for ... in dataset:` to iterate over a dataset. If using " 2112 "`tf.estimator`, return the `Dataset` object directly from your input " 2113 "function. As a last resort, you can use " 2114 "`tf.compat.v1.data.make_initializable_iterator(dataset)`.") 2115 def make_initializable_iterator(self, shared_name=None): 2116 """Creates an `Iterator` for enumerating the elements of this dataset. 2117 2118 Note: The returned iterator will be in an uninitialized state, 2119 and you must run the `iterator.initializer` operation before using it: 2120 2121 ```python 2122 dataset = ... 2123 iterator = dataset.make_initializable_iterator() 2124 # ... 2125 sess.run(iterator.initializer) 2126 ``` 2127 2128 Args: 2129 shared_name: (Optional.) 
If non-empty, the returned iterator will be 2130 shared under the given name across multiple sessions that share the same 2131 devices (e.g. when using a remote server). 2132 2133 Returns: 2134 An `Iterator` over the elements of this dataset. 2135 2136 Raises: 2137 RuntimeError: If eager execution is enabled. 2138 """ 2139 2140 return self._make_initializable_iterator(shared_name) 2141 2142 def _make_initializable_iterator(self, shared_name=None): # pylint: disable=missing-docstring 2143 if context.executing_eagerly(): 2144 raise RuntimeError( 2145 "dataset.make_initializable_iterator is not supported when eager " 2146 "execution is enabled. Use `for element in dataset` instead.") 2147 _ensure_same_dataset_graph(self) 2148 dataset = self._apply_options() 2149 if shared_name is None: 2150 shared_name = "" 2151 iterator_resource = gen_dataset_ops.iterator_v2( 2152 container="", shared_name=shared_name, **self._flat_structure) 2153 with ops.colocate_with(iterator_resource): 2154 initializer = gen_dataset_ops.make_iterator( 2155 dataset._variant_tensor, # pylint: disable=protected-access 2156 iterator_resource) 2157 # pylint: disable=protected-access 2158 return iterator_ops.Iterator( 2159 iterator_resource, initializer, get_legacy_output_types(dataset), 2160 get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset)) 2161 2162 @property 2163 @deprecation.deprecated( 2164 None, "Use `tf.compat.v1.data.get_output_classes(dataset)`.") 2165 def output_classes(self): 2166 """Returns the class of each component of an element of this dataset. 2167 2168 Returns: 2169 A nested structure of Python `type` objects corresponding to each 2170 component of an element of this dataset. 2171 """ 2172 return nest.map_structure( 2173 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 2174 self.element_spec) 2175 2176 @property 2177 @deprecation.deprecated( 2178 None, "Use `tf.compat.v1.data.get_output_shapes(dataset)`.") 2179 def output_shapes(self): 2180 """Returns the shape of each component of an element of this dataset. 2181 2182 Returns: 2183 A nested structure of `tf.TensorShape` objects corresponding to each 2184 component of an element of this dataset. 2185 """ 2186 return nest.map_structure( 2187 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 2188 self.element_spec) 2189 2190 @property 2191 @deprecation.deprecated( 2192 None, "Use `tf.compat.v1.data.get_output_types(dataset)`.") 2193 def output_types(self): 2194 """Returns the type of each component of an element of this dataset. 2195 2196 Returns: 2197 A nested structure of `tf.DType` objects corresponding to each component 2198 of an element of this dataset. 2199 """ 2200 return nest.map_structure( 2201 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 2202 self.element_spec) 2203 2204 @property 2205 def element_spec(self): 2206 # TODO(b/110122868): Remove this override once all `Dataset` instances 2207 # implement `element_structure`. 
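    # For V1 datasets the element spec is reconstructed from the legacy
    # (output_types, output_shapes, output_classes) triple exposed by the
    # deprecated properties above.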
2208 return structure.convert_legacy_structure( 2209 self.output_types, self.output_shapes, self.output_classes) 2210 2211 @staticmethod 2212 @functools.wraps(DatasetV2.from_tensors) 2213 def from_tensors(tensors): 2214 return DatasetV1Adapter(DatasetV2.from_tensors(tensors)) 2215 2216 @staticmethod 2217 @functools.wraps(DatasetV2.from_tensor_slices) 2218 def from_tensor_slices(tensors): 2219 return DatasetV1Adapter(DatasetV2.from_tensor_slices(tensors)) 2220 2221 @staticmethod 2222 @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") 2223 def from_sparse_tensor_slices(sparse_tensor): 2224 """Splits each rank-N `tf.SparseTensor` in this dataset row-wise. 2225 2226 Args: 2227 sparse_tensor: A `tf.SparseTensor`. 2228 2229 Returns: 2230 Dataset: A `Dataset` of rank-(N-1) sparse tensors. 2231 """ 2232 return DatasetV1Adapter(SparseTensorSliceDataset(sparse_tensor)) 2233 2234 @staticmethod 2235 @functools.wraps(DatasetV2.from_generator) 2236 def from_generator(generator, output_types, output_shapes=None, args=None): 2237 return DatasetV1Adapter(DatasetV2.from_generator( 2238 generator, output_types, output_shapes, args)) 2239 2240 @staticmethod 2241 @functools.wraps(DatasetV2.range) 2242 def range(*args, **kwargs): 2243 return DatasetV1Adapter(DatasetV2.range(*args, **kwargs)) 2244 2245 @staticmethod 2246 @functools.wraps(DatasetV2.zip) 2247 def zip(datasets): 2248 return DatasetV1Adapter(DatasetV2.zip(datasets)) 2249 2250 @functools.wraps(DatasetV2.concatenate) 2251 def concatenate(self, dataset): 2252 return DatasetV1Adapter(super(DatasetV1, self).concatenate(dataset)) 2253 2254 @functools.wraps(DatasetV2.prefetch) 2255 def prefetch(self, buffer_size): 2256 return DatasetV1Adapter(super(DatasetV1, self).prefetch(buffer_size)) 2257 2258 @staticmethod 2259 @functools.wraps(DatasetV2.list_files) 2260 def list_files(file_pattern, shuffle=None, seed=None): 2261 return DatasetV1Adapter(DatasetV2.list_files(file_pattern, shuffle, seed)) 2262 2263 @functools.wraps(DatasetV2.repeat) 2264 def repeat(self, count=None): 2265 return DatasetV1Adapter(super(DatasetV1, self).repeat(count)) 2266 2267 @functools.wraps(DatasetV2.shuffle) 2268 def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None): 2269 return DatasetV1Adapter(super(DatasetV1, self).shuffle( 2270 buffer_size, seed, reshuffle_each_iteration)) 2271 2272 @functools.wraps(DatasetV2.cache) 2273 def cache(self, filename=""): 2274 return DatasetV1Adapter(super(DatasetV1, self).cache(filename)) 2275 2276 @functools.wraps(DatasetV2.take) 2277 def take(self, count): 2278 return DatasetV1Adapter(super(DatasetV1, self).take(count)) 2279 2280 @functools.wraps(DatasetV2.skip) 2281 def skip(self, count): 2282 return DatasetV1Adapter(super(DatasetV1, self).skip(count)) 2283 2284 @functools.wraps(DatasetV2.shard) 2285 def shard(self, num_shards, index): 2286 return DatasetV1Adapter(super(DatasetV1, self).shard(num_shards, index)) 2287 2288 @functools.wraps(DatasetV2.batch) 2289 def batch(self, batch_size, drop_remainder=False): 2290 return DatasetV1Adapter(super(DatasetV1, self).batch( 2291 batch_size, drop_remainder)) 2292 2293 @functools.wraps(DatasetV2.padded_batch) 2294 def padded_batch(self, 2295 batch_size, 2296 padded_shapes=None, 2297 padding_values=None, 2298 drop_remainder=False): 2299 return DatasetV1Adapter(super(DatasetV1, self).padded_batch( 2300 batch_size, padded_shapes, padding_values, drop_remainder)) 2301 2302 @functools.wraps(DatasetV2.map) 2303 def map(self, map_func, num_parallel_calls=None): 2304 
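    # Identical to `DatasetV2.map()` (whose docstring is reused via
    # `functools.wraps`), except that `preserve_cardinality` is left as
    # `False`, keeping the historical V1 behavior.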
if num_parallel_calls is None: 2305 return DatasetV1Adapter( 2306 MapDataset(self, map_func, preserve_cardinality=False)) 2307 else: 2308 return DatasetV1Adapter( 2309 ParallelMapDataset( 2310 self, map_func, num_parallel_calls, preserve_cardinality=False)) 2311 2312 @deprecation.deprecated(None, "Use `tf.data.Dataset.map()") 2313 def map_with_legacy_function(self, map_func, num_parallel_calls=None): 2314 """Maps `map_func` across the elements of this dataset. 2315 2316 NOTE: This is an escape hatch for existing uses of `map` that do not work 2317 with V2 functions. New uses are strongly discouraged and existing uses 2318 should migrate to `map` as this method will be removed in V2. 2319 2320 Args: 2321 map_func: A function mapping a nested structure of tensors (having shapes 2322 and types defined by `self.output_shapes` and `self.output_types`) to 2323 another nested structure of tensors. 2324 num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, 2325 representing the number elements to process asynchronously in parallel. 2326 If not specified, elements will be processed sequentially. If the value 2327 `tf.data.experimental.AUTOTUNE` is used, then the number of parallel 2328 calls is set dynamically based on available CPU. 2329 2330 Returns: 2331 Dataset: A `Dataset`. 2332 """ 2333 if num_parallel_calls is None: 2334 return DatasetV1Adapter( 2335 MapDataset( 2336 self, 2337 map_func, 2338 preserve_cardinality=False, 2339 use_legacy_function=True)) 2340 else: 2341 return DatasetV1Adapter( 2342 ParallelMapDataset( 2343 self, 2344 map_func, 2345 num_parallel_calls, 2346 preserve_cardinality=False, 2347 use_legacy_function=True)) 2348 2349 @functools.wraps(DatasetV2.flat_map) 2350 def flat_map(self, map_func): 2351 return DatasetV1Adapter(super(DatasetV1, self).flat_map(map_func)) 2352 2353 @functools.wraps(DatasetV2.interleave) 2354 def interleave(self, 2355 map_func, 2356 cycle_length=AUTOTUNE, 2357 block_length=1, 2358 num_parallel_calls=None, 2359 deterministic=None): 2360 return DatasetV1Adapter( 2361 super(DatasetV1, self).interleave(map_func, cycle_length, block_length, 2362 num_parallel_calls, deterministic)) 2363 2364 @functools.wraps(DatasetV2.filter) 2365 def filter(self, predicate): 2366 return DatasetV1Adapter(super(DatasetV1, self).filter(predicate)) 2367 2368 @deprecation.deprecated(None, "Use `tf.data.Dataset.filter()") 2369 def filter_with_legacy_function(self, predicate): 2370 """Filters this dataset according to `predicate`. 2371 2372 NOTE: This is an escape hatch for existing uses of `filter` that do not work 2373 with V2 functions. New uses are strongly discouraged and existing uses 2374 should migrate to `filter` as this method will be removed in V2. 2375 2376 Args: 2377 predicate: A function mapping a nested structure of tensors (having shapes 2378 and types defined by `self.output_shapes` and `self.output_types`) to a 2379 scalar `tf.bool` tensor. 2380 2381 Returns: 2382 Dataset: The `Dataset` containing the elements of this dataset for which 2383 `predicate` is `True`. 
2384 """ 2385 return FilterDataset(self, predicate, use_legacy_function=True) 2386 2387 @functools.wraps(DatasetV2.apply) 2388 def apply(self, transformation_func): 2389 return DatasetV1Adapter(super(DatasetV1, self).apply(transformation_func)) 2390 2391 @functools.wraps(DatasetV2.window) 2392 def window(self, size, shift=None, stride=1, drop_remainder=False): 2393 return DatasetV1Adapter(super(DatasetV1, self).window( 2394 size, shift, stride, drop_remainder)) 2395 2396 @functools.wraps(DatasetV2.unbatch) 2397 def unbatch(self): 2398 return DatasetV1Adapter(super(DatasetV1, self).unbatch()) 2399 2400 @functools.wraps(DatasetV2.with_options) 2401 def with_options(self, options): 2402 return DatasetV1Adapter(super(DatasetV1, self).with_options(options)) 2403 2404 2405if tf2.enabled(): 2406 Dataset = DatasetV2 2407else: 2408 Dataset = DatasetV1 2409 2410 2411class DatasetV1Adapter(DatasetV1): 2412 """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API.""" 2413 2414 def __init__(self, dataset): 2415 self._dataset = dataset 2416 super(DatasetV1Adapter, self).__init__() 2417 2418 def _as_variant_tensor(self): 2419 return self._dataset._variant_tensor # pylint: disable=protected-access 2420 2421 def _has_captured_ref(self): 2422 return self._dataset._has_captured_ref() # pylint: disable=protected-access 2423 2424 def _inputs(self): 2425 return self._dataset._inputs() # pylint: disable=protected-access 2426 2427 def _functions(self): 2428 return self._dataset._functions() # pylint: disable=protected-access 2429 2430 def options(self): 2431 return self._dataset.options() 2432 2433 @property 2434 def element_spec(self): 2435 return self._dataset.element_spec # pylint: disable=protected-access 2436 2437 def __iter__(self): 2438 return iter(self._dataset) 2439 2440 2441def _ensure_same_dataset_graph(dataset): 2442 """Walks the dataset graph to ensure all datasets come from the same graph.""" 2443 # pylint: disable=protected-access 2444 current_graph = ops.get_default_graph() 2445 bfs_q = Queue.Queue() 2446 bfs_q.put(dataset) 2447 visited = [] 2448 while not bfs_q.empty(): 2449 ds = bfs_q.get() 2450 visited.append(ds) 2451 ds_graph = ds._graph 2452 if current_graph != ds_graph: 2453 raise ValueError( 2454 "The graph (" + str(current_graph) + ") of the iterator is different " 2455 "from the graph (" + str(ds_graph) + ") the dataset: " + 2456 str(ds._variant_tensor) + " was created in. If you are using the " 2457 "Estimator API, make sure that no part of the dataset returned by " 2458 "the `input_fn` function is defined outside the `input_fn` function. " 2459 "Please ensure that all datasets in the pipeline are created in the " 2460 "same graph as the iterator.") 2461 for input_ds in ds._inputs(): 2462 if input_ds not in visited: 2463 bfs_q.put(input_ds) 2464 2465 2466@tf_export(v1=["data.make_one_shot_iterator"]) 2467def make_one_shot_iterator(dataset): 2468 """Creates a `tf.compat.v1.data.Iterator` for enumerating dataset elements. 2469 2470 Note: The returned iterator will be initialized automatically. 2471 A "one-shot" iterator does not support re-initialization. 2472 2473 Args: 2474 dataset: A `tf.data.Dataset`. 2475 2476 Returns: 2477 A `tf.compat.v1.data.Iterator` over the elements of this dataset. 2478 """ 2479 try: 2480 # Call the defined `_make_one_shot_iterator()` if there is one, because some 2481 # datasets (e.g. for prefetching) override its behavior. 
2482 return dataset._make_one_shot_iterator() # pylint: disable=protected-access 2483 except AttributeError: 2484 return DatasetV1Adapter(dataset)._make_one_shot_iterator() # pylint: disable=protected-access 2485 2486 2487@tf_export(v1=["data.make_initializable_iterator"]) 2488def make_initializable_iterator(dataset, shared_name=None): 2489 """Creates a `tf.compat.v1.data.Iterator` for enumerating the elements of a dataset. 2490 2491 Note: The returned iterator will be in an uninitialized state, 2492 and you must run the `iterator.initializer` operation before using it: 2493 2494 ```python 2495 dataset = ... 2496 iterator = tf.compat.v1.data.make_initializable_iterator(dataset) 2497 # ... 2498 sess.run(iterator.initializer) 2499 ``` 2500 2501 Args: 2502 dataset: A `tf.data.Dataset`. 2503 shared_name: (Optional.) If non-empty, the returned iterator will be shared 2504 under the given name across multiple sessions that share the same devices 2505 (e.g. when using a remote server). 2506 2507 Returns: 2508 A `tf.compat.v1.data.Iterator` over the elements of `dataset`. 2509 2510 Raises: 2511 RuntimeError: If eager execution is enabled. 2512 """ 2513 try: 2514 # Call the defined `_make_initializable_iterator()` if there is one, because 2515 # some datasets (e.g. for prefetching) override its behavior. 2516 return dataset._make_initializable_iterator(shared_name) # pylint: disable=protected-access 2517 except AttributeError: 2518 return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name) # pylint: disable=protected-access 2519 2520 2521@tf_export("data.experimental.get_structure") 2522def get_structure(dataset_or_iterator): 2523 """Returns the type specification of an element of a `Dataset` or `Iterator`. 2524 2525 Args: 2526 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. 2527 2528 Returns: 2529 A nested structure of `tf.TypeSpec` objects matching the structure of an 2530 element of `dataset_or_iterator` and spacifying the type of individal 2531 components. 2532 2533 Raises: 2534 TypeError: If `dataset_or_iterator` is not a `Dataset` or `Iterator` object. 2535 """ 2536 try: 2537 return dataset_or_iterator.element_spec # pylint: disable=protected-access 2538 except AttributeError: 2539 raise TypeError("`dataset_or_iterator` must be a Dataset or Iterator " 2540 "object, but got %s." % type(dataset_or_iterator)) 2541 2542 2543@tf_export(v1=["data.get_output_classes"]) 2544def get_legacy_output_classes(dataset_or_iterator): 2545 """Returns the output classes of a `Dataset` or `Iterator` elements. 2546 2547 This utility method replaces the deprecated-in-V2 2548 `tf.compat.v1.Dataset.output_classes` property. 2549 2550 Args: 2551 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. 2552 2553 Returns: 2554 A nested structure of Python `type` objects matching the structure of the 2555 dataset / iterator elements and specifying the class of the individual 2556 components. 2557 """ 2558 return nest.map_structure( 2559 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 2560 get_structure(dataset_or_iterator)) 2561 2562 2563@tf_export(v1=["data.get_output_shapes"]) 2564def get_legacy_output_shapes(dataset_or_iterator): 2565 """Returns the output shapes of a `Dataset` or `Iterator` elements. 2566 2567 This utility method replaces the deprecated-in-V2 2568 `tf.compat.v1.Dataset.output_shapes` property. 2569 2570 Args: 2571 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. 
2572 2573 Returns: 2574 A nested structure of `tf.TensorShape` objects matching the structure of 2575 the dataset / iterator elements and specifying the shape of the individual 2576 components. 2577 """ 2578 return nest.map_structure( 2579 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 2580 get_structure(dataset_or_iterator)) 2581 2582 2583@tf_export(v1=["data.get_output_types"]) 2584def get_legacy_output_types(dataset_or_iterator): 2585 """Returns the output types of a `Dataset` or `Iterator` elements. 2586 2587 This utility method replaces the deprecated-in-V2 2588 `tf.compat.v1.Dataset.output_types` property. 2589 2590 Args: 2591 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. 2592 2593 Returns: 2594 A nested structure of `tf.DType` objects matching the structure of 2595 dataset / iterator elements and specifying the type of the individual 2596 components. 2597 """ 2598 return nest.map_structure( 2599 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 2600 get_structure(dataset_or_iterator)) 2601 2602 2603@tf_export("data.Options") 2604class Options(options_lib.OptionsBase): 2605 """Represents options for tf.data.Dataset. 2606 2607 An `Options` object can be, for instance, used to control which graph 2608 optimizations to apply or whether to use performance modeling to dynamically 2609 tune the parallelism of operations such as `tf.data.Dataset.map` or 2610 `tf.data.Dataset.interleave`. 2611 2612 After constructing an `Options` object, use `dataset.with_options(options)` to 2613 apply the options to a dataset. 2614 2615 >>> dataset = tf.data.Dataset.range(3) 2616 >>> options = tf.data.Options() 2617 >>> # Set options here. 2618 >>> dataset = dataset.with_options(options) 2619 """ 2620 2621 experimental_deterministic = options_lib.create_option( 2622 name="experimental_deterministic", 2623 ty=bool, 2624 docstring= 2625 "Whether the outputs need to be produced in deterministic order. If None," 2626 " defaults to True.") 2627 2628 experimental_distribute = options_lib.create_option( 2629 name="experimental_distribute", 2630 ty=distribute_options.DistributeOptions, 2631 docstring= 2632 "The distribution strategy options associated with the dataset. See " 2633 "`tf.data.experimental.DistributeOptions` for more details.", 2634 default_factory=distribute_options.DistributeOptions) 2635 2636 experimental_optimization = options_lib.create_option( 2637 name="experimental_optimization", 2638 ty=optimization_options.OptimizationOptions, 2639 docstring= 2640 "The optimization options associated with the dataset. See " 2641 "`tf.data.experimental.OptimizationOptions` for more details.", 2642 default_factory=optimization_options.OptimizationOptions) 2643 2644 experimental_slack = options_lib.create_option( 2645 name="experimental_slack", 2646 ty=bool, 2647 docstring="Whether to introduce 'slack' in the last `prefetch` of the " 2648 "input pipeline, if it exists. This may reduce CPU contention with " 2649 "accelerator host-side activity at the start of a step. The slack " 2650 "frequency is determined by the number of devices attached to this " 2651 "input pipeline. If None, defaults to False.") 2652 2653 experimental_stats = options_lib.create_option( 2654 name="experimental_stats", 2655 ty=stats_options.StatsOptions, 2656 docstring= 2657 "The statistics options associated with the dataset.
See " 2658 "`tf.data.experimental.StatsOptions` for more details.", 2659 default_factory=stats_options.StatsOptions) 2660 2661 experimental_threading = options_lib.create_option( 2662 name="experimental_threading", 2663 ty=threading_options.ThreadingOptions, 2664 docstring= 2665 "The threading options associated with the dataset. See " 2666 "`tf.data.experimental.ThreadingOptions` for more details.", 2667 default_factory=threading_options.ThreadingOptions) 2668 2669 experimental_external_state_policy = options_lib.create_option( 2670 name="experimental_external_state_policy", 2671 ty=distribute_options.ExternalStatePolicy, 2672 docstring="By default, tf.data will refuse to serialize a dataset or " 2673 "checkpoint its iterator if the dataset contains a stateful op as the " 2674 "serialization / checkpointing won't be able to capture its state. " 2675 "Users can -- at their own risk -- override this restriction by " 2676 "explicitly specifying that they are fine throwing away the state " 2677 "in these ops. There are three settings available - IGNORE: in which we" 2678 "completely ignore any state; WARN: We warn the user that some state " 2679 "might be thrown away; FAIL: We fail if any state is being captured.", 2680 default_factory=lambda: distribute_options.ExternalStatePolicy.WARN) 2681 2682 def _graph_rewrites(self): 2683 """Produces the list of enabled static graph rewrites.""" 2684 result = [] 2685 if self.experimental_optimization is not None: 2686 result.extend(self.experimental_optimization._graph_rewrites()) # pylint: disable=protected-access 2687 else: 2688 # Apply default options 2689 result.extend( 2690 optimization_options.OptimizationOptions()._graph_rewrites()) # pylint: disable=protected-access 2691 2692 if self.experimental_deterministic is False: 2693 result.append("make_sloppy") 2694 if self.experimental_stats and self.experimental_stats.latency_all_edges: 2695 result.append("latency_all_edges") 2696 if self.experimental_slack: 2697 result.append("slack") 2698 if (self.experimental_distribute and 2699 self.experimental_distribute._make_stateless): # pylint: disable=protected-access 2700 result.append("make_stateless") 2701 return result 2702 2703 def _graph_rewrite_configs(self): 2704 """Produces the list of configurations for enabled graph optimizations.""" 2705 result = [] 2706 if self.experimental_optimization: 2707 result.extend(self.experimental_optimization._graph_rewrite_configs()) # pylint: disable=protected-access 2708 2709 if self.experimental_slack: 2710 num_devices = self.experimental_distribute.num_devices 2711 if num_devices is None: 2712 num_devices = 1 2713 result.append("slack:slack_period:%d" % num_devices) 2714 return result 2715 2716 def _autotune_settings(self): 2717 if self.experimental_optimization is not None: 2718 return self.experimental_optimization._autotune_settings() # pylint: disable=protected-access 2719 2720 # Return default autotune options 2721 return optimization_options.OptimizationOptions()._autotune_settings() # pylint: disable=protected-access 2722 2723 def merge(self, options): 2724 """Merges itself with the given `tf.data.Options`. 2725 2726 The given `tf.data.Options` can be merged as long as there does not exist an 2727 attribute that is set to different values in `self` and `options`. 
2728 2729 Args: 2730 options: a `tf.data.Options` to merge with 2731 2732 Raises: 2733 ValueError: if the given `tf.data.Options` cannot be merged 2734 2735 Returns: 2736 New `tf.data.Options()` object which is the result of merging self with 2737 the input `tf.data.Options`. 2738 """ 2739 return options_lib.merge_options(self, options) 2740 2741 2742class DatasetSource(DatasetV2): 2743 """Abstract class representing a dataset with no inputs.""" 2744 2745 def _inputs(self): 2746 return [] 2747 2748 2749class UnaryDataset(DatasetV2): 2750 """Abstract class representing a dataset with one input.""" 2751 2752 def __init__(self, input_dataset, variant_tensor): 2753 self._input_dataset = input_dataset 2754 super(UnaryDataset, self).__init__(variant_tensor) 2755 2756 def _inputs(self): 2757 return [self._input_dataset] 2758 2759 2760class UnaryUnchangedStructureDataset(UnaryDataset): 2761 """Represents a unary dataset with the same input and output structure.""" 2762 2763 def __init__(self, input_dataset, variant_tensor): 2764 self._input_dataset = input_dataset 2765 super(UnaryUnchangedStructureDataset, self).__init__( 2766 input_dataset, variant_tensor) 2767 2768 @property 2769 def element_spec(self): 2770 return self._input_dataset.element_spec 2771 2772 2773class TensorDataset(DatasetSource): 2774 """A `Dataset` with a single element.""" 2775 2776 def __init__(self, element): 2777 """See `Dataset.from_tensors()` for details.""" 2778 element = structure.normalize_element(element) 2779 self._structure = structure.type_spec_from_value(element) 2780 self._tensors = structure.to_tensor_list(self._structure, element) 2781 2782 variant_tensor = gen_dataset_ops.tensor_dataset( 2783 self._tensors, 2784 output_shapes=structure.get_flat_tensor_shapes(self._structure)) 2785 super(TensorDataset, self).__init__(variant_tensor) 2786 2787 @property 2788 def element_spec(self): 2789 return self._structure 2790 2791 2792class TensorSliceDataset(DatasetSource): 2793 """A `Dataset` of slices from a dataset element.""" 2794 2795 def __init__(self, element): 2796 """See `Dataset.from_tensor_slices()` for details.""" 2797 element = structure.normalize_element(element) 2798 batched_spec = structure.type_spec_from_value(element) 2799 self._tensors = structure.to_batched_tensor_list(batched_spec, element) 2800 self._structure = nest.map_structure( 2801 lambda component_spec: component_spec._unbatch(), batched_spec) # pylint: disable=protected-access 2802 2803 batch_dim = tensor_shape.Dimension(tensor_shape.dimension_value( 2804 self._tensors[0].get_shape()[0])) 2805 for t in self._tensors[1:]: 2806 batch_dim.assert_is_compatible_with(tensor_shape.Dimension( 2807 tensor_shape.dimension_value(t.get_shape()[0]))) 2808 2809 variant_tensor = gen_dataset_ops.tensor_slice_dataset( 2810 self._tensors, 2811 output_shapes=structure.get_flat_tensor_shapes(self._structure)) 2812 super(TensorSliceDataset, self).__init__(variant_tensor) 2813 2814 @property 2815 def element_spec(self): 2816 return self._structure 2817 2818 2819class SparseTensorSliceDataset(DatasetSource): 2820 """A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows.""" 2821 2822 def __init__(self, sparse_tensor): 2823 """See `Dataset.from_sparse_tensor_slices()` for details.""" 2824 if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor): 2825 raise TypeError( 2826 "`sparse_tensor` must be a `tf.SparseTensor` object. 
Was {}.".format( 2827 sparse_tensor)) 2828 self._sparse_tensor = sparse_tensor 2829 2830 indices_shape = self._sparse_tensor.indices.get_shape() 2831 shape_shape = self._sparse_tensor.dense_shape.get_shape() 2832 rank = (indices_shape.dims[1] - 1).merge_with(shape_shape.dims[0] - 1) 2833 self._structure = (tensor_spec.TensorSpec([None, rank], dtypes.int64), 2834 tensor_spec.TensorSpec([None], 2835 self._sparse_tensor.dtype), 2836 tensor_spec.TensorSpec([rank], dtypes.int64)) 2837 2838 variant_tensor = gen_dataset_ops.sparse_tensor_slice_dataset( 2839 self._sparse_tensor.indices, self._sparse_tensor.values, 2840 self._sparse_tensor.dense_shape) 2841 super(SparseTensorSliceDataset, self).__init__(variant_tensor) 2842 2843 @property 2844 def element_spec(self): 2845 return self._structure 2846 2847 2848class _VariantDataset(DatasetV2): 2849 """A Dataset wrapper around a `tf.variant`-typed function argument.""" 2850 2851 def __init__(self, dataset_variant, structure): 2852 self._structure = structure 2853 super(_VariantDataset, self).__init__(dataset_variant) 2854 2855 def _inputs(self): 2856 return [] 2857 2858 @property 2859 def element_spec(self): 2860 return self._structure 2861 2862 2863class _NestedVariant(composite_tensor.CompositeTensor): 2864 2865 def __init__(self, variant_tensor, element_spec, dataset_shape): 2866 self._variant_tensor = variant_tensor 2867 self._element_spec = element_spec 2868 self._dataset_shape = dataset_shape 2869 2870 @property 2871 def _type_spec(self): 2872 return DatasetSpec(self._element_spec, self._dataset_shape) 2873 2874 2875@tf_export("data.experimental.from_variant") 2876def from_variant(variant, structure): 2877 """Constructs a dataset from the given variant and structure. 2878 2879 Args: 2880 variant: A scalar `tf.variant` tensor representing a dataset. 2881 structure: A `tf.data.experimental.Structure` object representing the 2882 structure of each element in the dataset. 2883 2884 Returns: 2885 A `tf.data.Dataset` instance. 2886 """ 2887 return _VariantDataset(variant, structure) # pylint: disable=protected-access 2888 2889 2890@tf_export("data.experimental.to_variant") 2891def to_variant(dataset): 2892 """Returns a variant representing the given dataset. 2893 2894 Args: 2895 dataset: A `tf.data.Dataset`. 2896 2897 Returns: 2898 A scalar `tf.variant` tensor representing the given dataset. 2899 """ 2900 return dataset._variant_tensor # pylint: disable=protected-access 2901 2902 2903@tf_export( 2904 "data.DatasetSpec", 2905 v1=["data.DatasetSpec", "data.experimental.DatasetStructure"]) 2906class DatasetSpec(type_spec.BatchableTypeSpec): 2907 """Type specification for `tf.data.Dataset`. 2908 2909 See `tf.TypeSpec` for more information about TensorFlow type specifications. 
2910 2911 >>> dataset = tf.data.Dataset.range(3) 2912 >>> tf.data.DatasetSpec.from_value(dataset) 2913 DatasetSpec(TensorSpec(shape=(), dtype=tf.int64, name=None), TensorShape([])) 2914 """ 2915 2916 __slots__ = ["_element_spec", "_dataset_shape"] 2917 2918 def __init__(self, element_spec, dataset_shape=()): 2919 self._element_spec = element_spec 2920 self._dataset_shape = tensor_shape.as_shape(dataset_shape) 2921 2922 @property 2923 def value_type(self): 2924 return _VariantDataset 2925 2926 def _serialize(self): 2927 return (self._element_spec, self._dataset_shape) 2928 2929 @property 2930 def _component_specs(self): 2931 return tensor_spec.TensorSpec(self._dataset_shape, dtypes.variant) 2932 2933 def _to_components(self, value): 2934 return value._variant_tensor # pylint: disable=protected-access 2935 2936 def _from_components(self, components): 2937 # pylint: disable=protected-access 2938 if self._dataset_shape.ndims == 0: 2939 return _VariantDataset(components, self._element_spec) 2940 else: 2941 return _NestedVariant(components, self._element_spec, self._dataset_shape) 2942 2943 def _to_tensor_list(self, value): 2944 return [ 2945 ops.convert_to_tensor( 2946 tf_nest.map_structure(lambda x: x._variant_tensor, value)) # pylint: disable=protected-access 2947 ] 2948 2949 @staticmethod 2950 def from_value(value): 2951 """Creates a `DatasetSpec` for the given `tf.data.Dataset` value.""" 2952 return DatasetSpec(value.element_spec) # pylint: disable=protected-access 2953 2954 def _batch(self, batch_size): 2955 return DatasetSpec( 2956 self._element_spec, 2957 tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape)) 2958 2959 def _unbatch(self): 2960 if self._dataset_shape.ndims == 0: 2961 raise ValueError("Unbatching a dataset is only supported for rank >= 1") 2962 return DatasetSpec(self._element_spec, self._dataset_shape[1:]) 2963 2964 def _to_batched_tensor_list(self, value): 2965 if self._dataset_shape.ndims == 0: 2966 raise ValueError("Unbatching a dataset is only supported for rank >= 1") 2967 return self._to_tensor_list(value) 2968 2969 def _to_legacy_output_types(self): 2970 return self 2971 2972 def _to_legacy_output_shapes(self): 2973 return self 2974 2975 def _to_legacy_output_classes(self): 2976 return self 2977 2978 2979class StructuredFunctionWrapper(object): 2980 """A function wrapper that supports structured arguments and return values.""" 2981 2982 # pylint: disable=protected-access 2983 def __init__(self, 2984 func, 2985 transformation_name, 2986 dataset=None, 2987 input_classes=None, 2988 input_shapes=None, 2989 input_types=None, 2990 input_structure=None, 2991 add_to_graph=True, 2992 use_legacy_function=False, 2993 defun_kwargs=None): 2994 """Creates a new `StructuredFunctionWrapper` for the given function. 2995 2996 Args: 2997 func: A function from a nested structure to another nested structure. 2998 transformation_name: Human-readable name of the transformation in which 2999 this function is being instantiated, for error messages. 3000 dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this 3001 dataset will be assumed as the structure for `func` arguments; otherwise 3002 `input_classes`, `input_shapes`, and `input_types` must be defined. 3003 input_classes: (Optional.) A nested structure of `type`. If given, this 3004 argument defines the Python types for `func` arguments. 3005 input_shapes: (Optional.) A nested structure of `tf.TensorShape`. If 3006 given, this argument defines the shapes and structure for `func` 3007 arguments. 
      input_types: (Optional.) A nested structure of `tf.DType`. If given, this
        argument defines the element types and structure for `func` arguments.
      input_structure: (Optional.) A `Structure` object. If given, this argument
        defines the element types and structure for `func` arguments.
      add_to_graph: (Optional.) If `True`, the function will be added to the
        default graph, if it exists.
      use_legacy_function: (Optional.) A boolean that determines whether the
        function should be created using `tensorflow.python.eager.function.defun`
        (default behavior) or `tensorflow.python.framework.function.Defun`
        (legacy behavior).
      defun_kwargs: (Optional.) A dictionary mapping string argument names to
        values. If supplied, will be passed to `function` as keyword arguments.

    Raises:
      ValueError: If an invalid combination of `dataset`, `input_classes`,
        `input_shapes`, and `input_types` is passed.
    """
    if input_structure is None:
      if dataset is None:
        if input_classes is None or input_shapes is None or input_types is None:
          raise ValueError("Either `dataset`, `input_structure` or all of "
                           "`input_classes`, `input_shapes`, and `input_types` "
                           "must be specified.")
        self._input_structure = structure.convert_legacy_structure(
            input_types, input_shapes, input_classes)
      else:
        if not (input_classes is None and input_shapes is None and
                input_types is None):
          raise ValueError("Either `dataset`, `input_structure` or all of "
                           "`input_classes`, `input_shapes`, and `input_types` "
                           "must be specified.")
        self._input_structure = dataset.element_spec
    else:
      if not (dataset is None and input_classes is None and input_shapes is None
              and input_types is None):
        raise ValueError("Either `dataset`, `input_structure`, or all of "
                         "`input_classes`, `input_shapes`, and `input_types` "
                         "must be specified.")
      self._input_structure = input_structure

    self._func = func

    # There is no graph to add in eager mode.
    add_to_graph &= not context.executing_eagerly()
    # There are some lifetime issues when a legacy function is not added to an
    # out-living graph. It's already deprecated, so de-prioritizing the fix.
    add_to_graph |= use_legacy_function

    if defun_kwargs is None:
      defun_kwargs = {}

    readable_transformation_name = transformation_name.replace(
        ".", "_")[:-2] if len(transformation_name) > 2 else ""

    func_name = "_".join(
        [readable_transformation_name,
         function_utils.get_func_name(func)])
    # Sanitize function name to remove symbols that interfere with graph
    # construction.
    for symbol in ["<", ">", "\\", "'", " "]:
      func_name = func_name.replace(symbol, "")

    ag_ctx = autograph_ctx.control_status_ctx()

    def _warn_if_collections(transformation_name):
      """Prints a warning if the given graph uses common graph collections.

      NOTE(mrry): Currently a warning is only generated for resources. Any
      variables created will be automatically hoisted out to the outermost scope
      using `init_scope()`. Some collections (such as for control-flow contexts)
      are benign and should not generate a warning.

      Args:
        transformation_name: A human-readable name for the transformation.
      """
      warnings.warn("Creating resources inside a function passed to %s "
                    "is not supported. Create each resource outside the "
                    "function, and capture it inside the function to use it." %
                    transformation_name, stacklevel=5)

    def _wrapper_helper(*args):
      """Wrapper for passing nested structures to and from tf.data functions."""
      nested_args = structure.from_compatible_tensor_list(
          self._input_structure, args)
      if not _should_unpack_args(nested_args):
        nested_args = (nested_args,)

      ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
      # If `func` returns a list of tensors, `nest.flatten()` and
      # `ops.convert_to_tensor()` would conspire to attempt to stack
      # those tensors into a single tensor, because the customized
      # version of `nest.flatten()` does not recurse into lists. Since
      # it is more likely that the list arose from returning the
      # result of an operation (such as `tf.numpy_function()`) that returns a
      # list of not-necessarily-stackable tensors, we treat the
      # returned value as a `tuple` instead. A user wishing to pack
      # the return value into a single tensor can use an explicit
      # `tf.stack()` before returning.
      if isinstance(ret, list):
        ret = tuple(ret)

      try:
        self._output_structure = structure.type_spec_from_value(ret)
      except (ValueError, TypeError):
        six.reraise(
            TypeError,
            TypeError("Unsupported return value from function passed to "
                      "%s: %s." % (transformation_name, ret)),
            sys.exc_info()[2])
      return ret

    if use_legacy_function:
      func_name = func_name + "_" + str(ops.uid())

      @function.Defun(
          *structure.get_flat_tensor_types(self._input_structure),
          func_name=func_name,
          **defun_kwargs)
      def wrapper_fn(*args):
        ret = _wrapper_helper(*args)
        # _warn_if_collections(transformation_name, ops.get_default_graph(), 0)
        return structure.to_tensor_list(self._output_structure, ret)

      self._function = wrapper_fn
      resource_tracker = tracking.ResourceTracker()
      with tracking.resource_tracker_scope(resource_tracker):
        if add_to_graph:
          self._function.add_to_graph(ops.get_default_graph())
        else:
          # Use the private method that will execute `wrapper_fn` but delay
          # adding it to the graph in case (e.g.) we need to rerun the function.
          self._function._create_definition_if_needed()
        if resource_tracker.resources:
          _warn_if_collections(transformation_name)

    else:
      defun_kwargs.update({"func_name": func_name})

      # Note: _wrapper_helper will apply autograph based on context.
      @eager_function.defun_with_attributes(
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure),
          autograph=False,
          attributes=defun_kwargs)
      def wrapper_fn(*args):  # pylint: disable=missing-docstring
        ret = _wrapper_helper(*args)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      resource_tracker = tracking.ResourceTracker()
      with tracking.resource_tracker_scope(resource_tracker):
        # TODO(b/141462134): Switch to using garbage collection.
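        # NOTE: `get_concrete_function()` traces `wrapper_fn` here, which runs
        # `_wrapper_helper` once symbolically; that trace is what populates
        # `self._output_structure` before any dataset op is constructed.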
3160 self._function = wrapper_fn.get_concrete_function() 3161 3162 if add_to_graph: 3163 self._function.add_to_graph(ops.get_default_graph()) 3164 if resource_tracker.resources: 3165 _warn_if_collections(transformation_name) 3166 3167 outer_graph_seed = ops.get_default_graph().seed 3168 if outer_graph_seed and self._function.graph.seed == outer_graph_seed: 3169 if self._function.graph._seed_used: 3170 warnings.warn( 3171 "Seed %s from outer graph might be getting used by function %s, " 3172 "if the random op has not been provided any seed. Explicitly set " 3173 "the seed in the function if this is not the intended behavior." 3174 %(outer_graph_seed, func_name), stacklevel=4) 3175 # pylint: enable=protected-access 3176 3177 @property 3178 def output_structure(self): 3179 return self._output_structure 3180 3181 @property 3182 def output_classes(self): 3183 return nest.map_structure( 3184 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 3185 self._output_structure) 3186 3187 @property 3188 def output_shapes(self): 3189 return nest.map_structure( 3190 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 3191 self._output_structure) 3192 3193 @property 3194 def output_types(self): 3195 return nest.map_structure( 3196 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 3197 self._output_structure) 3198 3199 @property 3200 def function(self): 3201 return self._function 3202 3203 3204class _GeneratorDataset(DatasetSource): 3205 """A `Dataset` that generates elements by invoking a function.""" 3206 3207 def __init__(self, init_args, init_func, next_func, finalize_func): 3208 """Constructs a `_GeneratorDataset`. 3209 3210 Args: 3211 init_args: A nested structure representing the arguments to `init_func`. 3212 init_func: A TensorFlow function that will be called on `init_args` each 3213 time a C++ iterator over this dataset is constructed. Returns a nested 3214 structure representing the "state" of the dataset. 3215 next_func: A TensorFlow function that will be called on the result of 3216 `init_func` to produce each element, and that raises `OutOfRangeError` 3217 to terminate iteration. 3218 finalize_func: A TensorFlow function that will be called on the result of 3219 `init_func` immediately before a C++ iterator over this dataset is 3220 destroyed. The return value is ignored. 
3221 """ 3222 self._init_args = init_args 3223 3224 self._init_structure = structure.type_spec_from_value(init_args) 3225 3226 self._init_func = StructuredFunctionWrapper( 3227 init_func, 3228 self._transformation_name(), 3229 input_structure=self._init_structure) 3230 3231 self._next_func = StructuredFunctionWrapper( 3232 next_func, 3233 self._transformation_name(), 3234 input_structure=self._init_func.output_structure) 3235 3236 self._finalize_func = StructuredFunctionWrapper( 3237 finalize_func, 3238 self._transformation_name(), 3239 input_structure=self._init_func.output_structure) 3240 variant_tensor = gen_dataset_ops.generator_dataset( 3241 structure.to_tensor_list(self._init_structure, self._init_args) + 3242 self._init_func.function.captured_inputs, 3243 self._next_func.function.captured_inputs, 3244 self._finalize_func.function.captured_inputs, 3245 init_func=self._init_func.function, 3246 next_func=self._next_func.function, 3247 finalize_func=self._finalize_func.function, 3248 **self._flat_structure) 3249 super(_GeneratorDataset, self).__init__(variant_tensor) 3250 3251 @property 3252 def element_spec(self): 3253 return self._next_func.output_structure 3254 3255 def _transformation_name(self): 3256 return "Dataset.from_generator()" 3257 3258 3259class ZipDataset(DatasetV2): 3260 """A `Dataset` that zips its inputs together.""" 3261 3262 def __init__(self, datasets): 3263 """See `Dataset.zip()` for details.""" 3264 for ds in nest.flatten(datasets): 3265 if not isinstance(ds, DatasetV2): 3266 if isinstance(ds, list): 3267 message = ("The argument to `Dataset.zip()` must be a nested " 3268 "structure of `Dataset` objects. Nested structures do not " 3269 "support Python lists; please use a tuple instead.") 3270 else: 3271 message = ("The argument to `Dataset.zip()` must be a nested " 3272 "structure of `Dataset` objects.") 3273 raise TypeError(message) 3274 self._datasets = datasets 3275 self._structure = nest.pack_sequence_as( 3276 self._datasets, 3277 [ds.element_spec for ds in nest.flatten(self._datasets)]) 3278 variant_tensor = gen_dataset_ops.zip_dataset( 3279 [ds._variant_tensor for ds in nest.flatten(self._datasets)], 3280 **self._flat_structure) 3281 super(ZipDataset, self).__init__(variant_tensor) 3282 3283 def _inputs(self): 3284 return nest.flatten(self._datasets) 3285 3286 @property 3287 def element_spec(self): 3288 return self._structure 3289 3290 3291class ConcatenateDataset(DatasetV2): 3292 """A `Dataset` that concatenates its input with given dataset.""" 3293 3294 def __init__(self, input_dataset, dataset_to_concatenate): 3295 """See `Dataset.concatenate()` for details.""" 3296 self._input_dataset = input_dataset 3297 self._dataset_to_concatenate = dataset_to_concatenate 3298 3299 output_types = get_legacy_output_types(input_dataset) 3300 if output_types != get_legacy_output_types(dataset_to_concatenate): 3301 raise TypeError( 3302 "Two datasets to concatenate have different types %s and %s" % 3303 (output_types, get_legacy_output_types(dataset_to_concatenate))) 3304 3305 output_classes = get_legacy_output_classes(input_dataset) 3306 if output_classes != get_legacy_output_classes(dataset_to_concatenate): 3307 raise TypeError( 3308 "Two datasets to concatenate have different classes %s and %s" % 3309 (output_classes, get_legacy_output_classes(dataset_to_concatenate))) 3310 3311 input_shapes = get_legacy_output_shapes(self._input_dataset) 3312 output_shapes = nest.pack_sequence_as(input_shapes, [ 3313 ts1.most_specific_compatible_shape(ts2) 3314 for (ts1, ts2) in zip( 
3315 nest.flatten(input_shapes), 3316 nest.flatten(get_legacy_output_shapes( 3317 self._dataset_to_concatenate))) 3318 ]) 3319 3320 self._structure = structure.convert_legacy_structure( 3321 output_types, output_shapes, output_classes) 3322 3323 self._input_datasets = [input_dataset, dataset_to_concatenate] 3324 # pylint: disable=protected-access 3325 variant_tensor = gen_dataset_ops.concatenate_dataset( 3326 input_dataset._variant_tensor, dataset_to_concatenate._variant_tensor, 3327 **self._flat_structure) 3328 # pylint: enable=protected-access 3329 super(ConcatenateDataset, self).__init__(variant_tensor) 3330 3331 def _inputs(self): 3332 return self._input_datasets 3333 3334 @property 3335 def element_spec(self): 3336 return self._structure 3337 3338 3339class RepeatDataset(UnaryUnchangedStructureDataset): 3340 """A `Dataset` that repeats its input several times.""" 3341 3342 def __init__(self, input_dataset, count): 3343 """See `Dataset.repeat()` for details.""" 3344 self._input_dataset = input_dataset 3345 if count is None: 3346 self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") 3347 else: 3348 self._count = ops.convert_to_tensor( 3349 count, dtype=dtypes.int64, name="count") 3350 variant_tensor = gen_dataset_ops.repeat_dataset( 3351 input_dataset._variant_tensor, # pylint: disable=protected-access 3352 count=self._count, 3353 **self._flat_structure) 3354 super(RepeatDataset, self).__init__(input_dataset, variant_tensor) 3355 3356 3357class RangeDataset(DatasetSource): 3358 """A `Dataset` of a step separated range of values.""" 3359 3360 def __init__(self, *args, **kwargs): 3361 """See `Dataset.range()` for details.""" 3362 self._parse_args(*args, **kwargs) 3363 self._structure = tensor_spec.TensorSpec([], self._output_type) 3364 variant_tensor = gen_dataset_ops.range_dataset( 3365 start=self._start, 3366 stop=self._stop, 3367 step=self._step, 3368 **self._flat_structure) 3369 super(RangeDataset, self).__init__(variant_tensor) 3370 3371 def _parse_args(self, *args, **kwargs): 3372 """Parse arguments according to the same rules as the `range()` builtin.""" 3373 if len(args) == 1: 3374 self._start = self._build_tensor(0, "start") 3375 self._stop = self._build_tensor(args[0], "stop") 3376 self._step = self._build_tensor(1, "step") 3377 elif len(args) == 2: 3378 self._start = self._build_tensor(args[0], "start") 3379 self._stop = self._build_tensor(args[1], "stop") 3380 self._step = self._build_tensor(1, "step") 3381 elif len(args) == 3: 3382 self._start = self._build_tensor(args[0], "start") 3383 self._stop = self._build_tensor(args[1], "stop") 3384 self._step = self._build_tensor(args[2], "step") 3385 else: 3386 raise ValueError("Invalid arguments to RangeDataset: %s" % str(args)) 3387 if "output_type" in kwargs: 3388 self._output_type = kwargs["output_type"] 3389 else: 3390 self._output_type = dtypes.int64 3391 3392 def _build_tensor(self, int64_value, name): 3393 return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name) 3394 3395 @property 3396 def element_spec(self): 3397 return self._structure 3398 3399 3400class _MemoryCacheDeleter(object): 3401 """An object which cleans up an anonymous memory cache resource. 3402 3403 An alternative to defining a __del__ method on an object. Even if the parent 3404 object is part of a reference cycle, the cycle will be collectable. 
3405 """ 3406 3407 def __init__(self, handle, device, deleter): 3408 self._deleter = deleter 3409 self._handle = handle 3410 self._device = device 3411 self._eager_mode = context.executing_eagerly() 3412 3413 def __del__(self): 3414 with ops.device(self._device): 3415 # Make sure the resource is deleted in the same mode as it was created in. 3416 if self._eager_mode: 3417 with context.eager_mode(): 3418 gen_dataset_ops.delete_memory_cache( 3419 handle=self._handle, deleter=self._deleter) 3420 else: 3421 with context.graph_mode(): 3422 gen_dataset_ops.delete_memory_cache( 3423 handle=self._handle, deleter=self._deleter) 3424 3425 3426class _MemoryCache(object): 3427 """Represents a memory cache resource.""" 3428 3429 def __init__(self): 3430 super(_MemoryCache, self).__init__() 3431 self._device = context.context().device_name 3432 self._handle, self._deleter = (gen_dataset_ops.anonymous_memory_cache()) 3433 self._resource_deleter = _MemoryCacheDeleter( 3434 handle=self._handle, device=self._device, deleter=self._deleter) 3435 3436 @property 3437 def handle(self): 3438 return self._handle 3439 3440 3441class CacheDataset(UnaryUnchangedStructureDataset): 3442 """A `Dataset` that caches elements of its input.""" 3443 3444 def __init__(self, input_dataset, filename): 3445 """See `Dataset.cache()` for details.""" 3446 self._input_dataset = input_dataset 3447 self._filename = ops.convert_to_tensor( 3448 filename, dtype=dtypes.string, name="filename") 3449 if tf2.enabled() and (context.executing_eagerly() or 3450 ops.get_default_graph()._building_function): # pylint: disable=protected-access 3451 self._cache = _MemoryCache() 3452 variant_tensor = gen_dataset_ops.cache_dataset_v2( 3453 input_dataset._variant_tensor, # pylint: disable=protected-access 3454 filename=self._filename, 3455 cache=self._cache.handle, 3456 **self._flat_structure) 3457 else: 3458 variant_tensor = gen_dataset_ops.cache_dataset( 3459 input_dataset._variant_tensor, # pylint: disable=protected-access 3460 filename=self._filename, 3461 **self._flat_structure) 3462 super(CacheDataset, self).__init__(input_dataset, variant_tensor) 3463 3464 3465class _RandomSeedGeneratorDeleter(object): 3466 """An object which cleans up an anonymous random seed generator resource. 3467 3468 An alternative to defining a __del__ method on an object. Even if the parent 3469 object is part of a reference cycle, the cycle will be collectable. 3470 """ 3471 3472 def __init__(self, handle, device, deleter): 3473 self._deleter = deleter 3474 self._handle = handle 3475 self._device = device 3476 self._eager_mode = context.executing_eagerly() 3477 3478 def __del__(self): 3479 with ops.device(self._device): 3480 # Make sure the resource is deleted in the same mode as it was created in. 
3481 if self._eager_mode: 3482 with context.eager_mode(): 3483 gen_dataset_ops.delete_random_seed_generator( 3484 handle=self._handle, deleter=self._deleter) 3485 else: 3486 with context.graph_mode(): 3487 gen_dataset_ops.delete_random_seed_generator( 3488 handle=self._handle, deleter=self._deleter) 3489 3490 3491class _RandomSeedGenerator(object): 3492 """Represents a random seed generator resource.""" 3493 3494 def __init__(self, seed, seed2): 3495 super(_RandomSeedGenerator, self).__init__() 3496 self._device = context.context().device_name 3497 self._handle, self._deleter = ( 3498 gen_dataset_ops.anonymous_random_seed_generator(seed=seed, seed2=seed2)) 3499 self._resource_deleter = _RandomSeedGeneratorDeleter( 3500 handle=self._handle, device=self._device, deleter=self._deleter) 3501 3502 @property 3503 def handle(self): 3504 return self._handle 3505 3506 3507class ShuffleDataset(UnaryUnchangedStructureDataset): 3508 """A `Dataset` that randomly shuffles the elements of its input.""" 3509 3510 def __init__(self, 3511 input_dataset, 3512 buffer_size, 3513 seed=None, 3514 reshuffle_each_iteration=None): 3515 """Randomly shuffles the elements of this dataset. 3516 3517 Args: 3518 input_dataset: The input dataset. 3519 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 3520 elements from this dataset from which the new dataset will sample. 3521 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random 3522 seed that will be used to create the distribution. See 3523 `tf.random.set_seed` for behavior. 3524 reshuffle_each_iteration: (Optional.) A boolean, which if true indicates 3525 that the dataset should be pseudorandomly reshuffled each time it is 3526 iterated over. (Defaults to `True`.) 3527 3528 Returns: 3529 A `Dataset`. 3530 3531 Raises: 3532 ValueError: if invalid arguments are provided. 
3533 """ 3534 self._input_dataset = input_dataset 3535 self._buffer_size = ops.convert_to_tensor( 3536 buffer_size, dtype=dtypes.int64, name="buffer_size") 3537 self._seed, self._seed2 = random_seed.get_seed(seed) 3538 3539 if reshuffle_each_iteration is None: 3540 self._reshuffle_each_iteration = True 3541 else: 3542 self._reshuffle_each_iteration = reshuffle_each_iteration 3543 3544 if tf2.enabled() and self._reshuffle_each_iteration and ( 3545 context.executing_eagerly() or 3546 ops.get_default_graph()._building_function): # pylint: disable=protected-access 3547 self._seed_generator = _RandomSeedGenerator(self._seed, self._seed2) 3548 variant_tensor = gen_dataset_ops.shuffle_dataset_v2( 3549 input_dataset._variant_tensor, # pylint: disable=protected-access 3550 buffer_size=self._buffer_size, 3551 seed_generator=self._seed_generator.handle, 3552 **self._flat_structure) 3553 else: 3554 variant_tensor = gen_dataset_ops.shuffle_dataset( 3555 input_dataset._variant_tensor, # pylint: disable=protected-access 3556 buffer_size=self._buffer_size, 3557 seed=self._seed, 3558 seed2=self._seed2, 3559 reshuffle_each_iteration=self._reshuffle_each_iteration, 3560 **self._flat_structure) 3561 super(ShuffleDataset, self).__init__(input_dataset, variant_tensor) 3562 3563 3564class TakeDataset(UnaryUnchangedStructureDataset): 3565 """A `Dataset` containing the first `count` elements from its input.""" 3566 3567 def __init__(self, input_dataset, count): 3568 """See `Dataset.take()` for details.""" 3569 self._input_dataset = input_dataset 3570 self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") 3571 variant_tensor = gen_dataset_ops.take_dataset( 3572 input_dataset._variant_tensor, # pylint: disable=protected-access 3573 count=self._count, 3574 **self._flat_structure) 3575 super(TakeDataset, self).__init__(input_dataset, variant_tensor) 3576 3577 3578class SkipDataset(UnaryUnchangedStructureDataset): 3579 """A `Dataset` skipping the first `count` elements from its input.""" 3580 3581 def __init__(self, input_dataset, count): 3582 """See `Dataset.skip()` for details.""" 3583 self._input_dataset = input_dataset 3584 self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") 3585 variant_tensor = gen_dataset_ops.skip_dataset( 3586 input_dataset._variant_tensor, # pylint: disable=protected-access 3587 count=self._count, 3588 **self._flat_structure) 3589 super(SkipDataset, self).__init__(input_dataset, variant_tensor) 3590 3591 3592class ShardDataset(UnaryUnchangedStructureDataset): 3593 """A `Dataset` for sharding its input.""" 3594 3595 def __init__(self, input_dataset, num_shards, index): 3596 """See `Dataset.shard()` for details.""" 3597 self._input_dataset = input_dataset 3598 self._num_shards = ops.convert_to_tensor( 3599 num_shards, dtype=dtypes.int64, name="num_shards") 3600 self._index = ops.convert_to_tensor(index, dtype=dtypes.int64, name="index") 3601 variant_tensor = gen_dataset_ops.shard_dataset( 3602 input_dataset._variant_tensor, # pylint: disable=protected-access 3603 num_shards=self._num_shards, 3604 index=self._index, 3605 **self._flat_structure) 3606 super(ShardDataset, self).__init__(input_dataset, variant_tensor) 3607 3608 3609class BatchDataset(UnaryDataset): 3610 """A `Dataset` that batches contiguous elements from its input.""" 3611 3612 def __init__(self, input_dataset, batch_size, drop_remainder): 3613 """See `Dataset.batch()` for details.""" 3614 self._input_dataset = input_dataset 3615 self._batch_size = ops.convert_to_tensor( 3616 
batch_size, dtype=dtypes.int64, name="batch_size") 3617 self._drop_remainder = ops.convert_to_tensor( 3618 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 3619 3620 constant_drop_remainder = tensor_util.constant_value(self._drop_remainder) 3621 # pylint: disable=protected-access 3622 if constant_drop_remainder: 3623 # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically) 3624 # or `False` (explicitly retaining the remainder). 3625 # pylint: disable=g-long-lambda 3626 constant_batch_size = tensor_util.constant_value(self._batch_size) 3627 self._structure = nest.map_structure( 3628 lambda component_spec: component_spec._batch(constant_batch_size), 3629 input_dataset.element_spec) 3630 else: 3631 self._structure = nest.map_structure( 3632 lambda component_spec: component_spec._batch(None), 3633 input_dataset.element_spec) 3634 variant_tensor = gen_dataset_ops.batch_dataset_v2( 3635 input_dataset._variant_tensor, 3636 batch_size=self._batch_size, 3637 drop_remainder=self._drop_remainder, 3638 **self._flat_structure) 3639 super(BatchDataset, self).__init__(input_dataset, variant_tensor) 3640 3641 @property 3642 def element_spec(self): 3643 return self._structure 3644 3645 3646class _NumpyIterator(object): 3647 """Iterator over a dataset with elements converted to numpy.""" 3648 3649 def __init__(self, dataset): 3650 self._iterator = iter(dataset) 3651 3652 def __iter__(self): 3653 return self 3654 3655 def next(self): 3656 return nest.map_structure(lambda x: x.numpy(), next(self._iterator)) 3657 3658 def __next__(self): 3659 return self.next() 3660 3661 3662class _VariantTracker(tracking.CapturableResource): 3663 """Allows export of functions capturing a Dataset in SavedModels. 3664 3665 When saving a SavedModel, `tf.saved_model.save` traverses the object 3666 graph. Since Datasets reference _VariantTracker objects, that traversal will 3667 find a _VariantTracker for each Dataset and so know how to save and restore 3668 functions which reference the Dataset's variant Tensor. 3669 """ 3670 3671 def __init__(self, variant_tensor, resource_creator): 3672 """Record that `variant_tensor` is associated with `resource_creator`. 3673 3674 Args: 3675 variant_tensor: The variant-dtype Tensor associated with the Dataset. This 3676 Tensor will be a captured input to functions which use the Dataset, and 3677 is used by saving code to identify the corresponding _VariantTracker. 3678 resource_creator: A zero-argument function which creates a new 3679 variant-dtype Tensor. This function will be included in SavedModels and 3680 run to re-create the Dataset's variant Tensor on restore. 3681 """ 3682 super(_VariantTracker, self).__init__(device="CPU") 3683 self._resource_handle = variant_tensor 3684 self._create_resource = resource_creator 3685 3686 3687def _is_padded_shape_compatible_with(padded_shape, input_component_shape): 3688 """Returns `True` if `input_component_shape` can be padded to `padded_shape`. 3689 3690 Args: 3691 padded_shape: A `tf.TensorShape`. 3692 input_component_shape: A `tf.TensorShape`. 3693 3694 Returns: 3695 `True` if `input_component_shape` can be padded to `padded_shape`, otherwise 3696 `False`. 
3697 """ 3698 3699 if padded_shape.dims is None or input_component_shape.dims is None: 3700 return True 3701 if len(padded_shape.dims) != len(input_component_shape.dims): 3702 return False 3703 for padded_dim, input_dim in zip( 3704 padded_shape.dims, input_component_shape.dims): 3705 if (padded_dim.value is not None and input_dim.value is not None 3706 and padded_dim.value < input_dim.value): 3707 return False 3708 return True 3709 3710 3711def _padded_shape_to_tensor(padded_shape, input_component_shape): 3712 """Converts `padded_shape` to a `tf.Tensor` representing that shape. 3713 3714 Args: 3715 padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python 3716 sequence, or a 1-D `tf.Tensor` of `tf.int64` elements. 3717 input_component_shape: A `tf.TensorShape`, with which `padded_shape` must 3718 be compatible. 3719 3720 Returns: 3721 A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`. 3722 3723 Raises: 3724 ValueError: If `padded_shape` is not a shape or not compatible with 3725 `input_component_shape`. 3726 TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor. 3727 """ 3728 try: 3729 # Try to convert the `padded_shape` to a `tf.TensorShape` 3730 padded_shape_as_shape = tensor_shape.as_shape(padded_shape) 3731 # We will return the "canonical" tensor representation, which uses 3732 # `-1` in place of `None`. 3733 ret = ops.convert_to_tensor( 3734 [dim if dim is not None else -1 3735 for dim in padded_shape_as_shape.as_list()], dtype=dtypes.int64) 3736 except (TypeError, ValueError): 3737 # The argument was not trivially convertible to a 3738 # `tf.TensorShape`, so fall back on the conversion to tensor 3739 # machinery. 3740 ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64) 3741 if ret.shape.dims is not None and len(ret.shape.dims) != 1: 3742 six.reraise(ValueError, ValueError( 3743 "Padded shape %s must be a 1-D tensor of tf.int64 values, but its " 3744 "shape was %s." % (padded_shape, ret.shape)), sys.exc_info()[2]) 3745 if ret.dtype != dtypes.int64: 3746 six.reraise( 3747 TypeError, 3748 TypeError( 3749 "Padded shape %s must be a 1-D tensor of tf.int64 values, but " 3750 "its element type was %s." % (padded_shape, ret.dtype.name)), 3751 sys.exc_info()[2]) 3752 padded_shape_as_shape = tensor_util.constant_value_as_shape(ret) 3753 3754 if not _is_padded_shape_compatible_with(padded_shape_as_shape, 3755 input_component_shape): 3756 raise ValueError("The padded shape %s is not compatible with the " 3757 "corresponding input component shape %s." 3758 % (padded_shape_as_shape, input_component_shape)) 3759 3760 return ret 3761 3762 3763def _padding_value_to_tensor(value, output_type): 3764 """Converts the padding value to a tensor. 3765 3766 Args: 3767 value: The padding value. 3768 output_type: Its expected dtype. 3769 3770 Returns: 3771 A scalar `Tensor`. 3772 3773 Raises: 3774 ValueError: if the padding value is not a scalar. 3775 TypeError: if the padding value's type does not match `output_type`. 
3776 """ 3777 value = ops.convert_to_tensor(value, name="padding_value") 3778 if not value.shape.is_compatible_with(tensor_shape.TensorShape([])): 3779 raise ValueError("Padding value should be a scalar, but is not: %s" % value) 3780 if value.dtype != output_type: 3781 raise TypeError("Padding value tensor (%s) does not match output type: %s" % 3782 (value, output_type)) 3783 return value 3784 3785 3786def _padding_values_or_default(padding_values, input_dataset): 3787 """Returns padding values with None elements replaced with default values.""" 3788 def make_zero(t): 3789 if t.base_dtype == dtypes.string: 3790 return "" 3791 elif t.base_dtype == dtypes.variant: 3792 error_msg = ("Unable to create padding for field of type 'variant' " 3793 "because t.base_type == dtypes.variant == " 3794 "{}.".format( 3795 t.base_dtype)) 3796 raise TypeError(error_msg) 3797 else: 3798 return np.zeros_like(t.as_numpy_dtype()) 3799 def value_or_default(value, default): 3800 return default if value is None else value 3801 3802 default_padding = nest.map_structure(make_zero, 3803 get_legacy_output_types(input_dataset)) 3804 return nest.map_structure_up_to(padding_values, value_or_default, 3805 padding_values, default_padding) 3806 3807 3808class PaddedBatchDataset(UnaryDataset): 3809 """A `Dataset` that batches and pads contiguous elements from its input.""" 3810 3811 def __init__(self, input_dataset, batch_size, padded_shapes, padding_values, 3812 drop_remainder): 3813 """See `Dataset.batch()` for details.""" 3814 self._input_dataset = input_dataset 3815 if sparse.any_sparse(get_legacy_output_classes(input_dataset)): 3816 # TODO(b/63669786): support batching of sparse tensors 3817 raise TypeError( 3818 "Batching of padded sparse tensors is not currently supported") 3819 self._input_dataset = input_dataset 3820 self._batch_size = ops.convert_to_tensor( 3821 batch_size, dtype=dtypes.int64, name="batch_size") 3822 padding_values = _padding_values_or_default(padding_values, input_dataset) 3823 3824 input_shapes = get_legacy_output_shapes(input_dataset) 3825 flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes) 3826 3827 flat_padded_shapes_as_tensors = [] 3828 3829 for input_component_shape, padded_shape in zip( 3830 nest.flatten(input_shapes), flat_padded_shapes): 3831 flat_padded_shapes_as_tensors.append( 3832 _padded_shape_to_tensor(padded_shape, input_component_shape)) 3833 3834 self._padded_shapes = nest.pack_sequence_as(input_shapes, 3835 flat_padded_shapes_as_tensors) 3836 3837 self._padding_values = nest.map_structure_up_to( 3838 input_shapes, _padding_value_to_tensor, padding_values, 3839 get_legacy_output_types(input_dataset)) 3840 self._drop_remainder = ops.convert_to_tensor( 3841 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 3842 3843 def _padded_shape_to_batch_shape(s): 3844 return tensor_shape.TensorShape([ 3845 tensor_util.constant_value(self._batch_size) 3846 if smart_cond.smart_constant_value(self._drop_remainder) else None 3847 ]).concatenate(tensor_util.constant_value_as_shape(s)) 3848 3849 output_shapes = nest.map_structure( 3850 _padded_shape_to_batch_shape, self._padded_shapes) 3851 self._structure = structure.convert_legacy_structure( 3852 get_legacy_output_types(self._input_dataset), output_shapes, 3853 get_legacy_output_classes(self._input_dataset)) 3854 3855 # pylint: disable=protected-access 3856 # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018. 
3857 if smart_cond.smart_constant_value(self._drop_remainder) is False: 3858 variant_tensor = gen_dataset_ops.padded_batch_dataset( 3859 input_dataset._variant_tensor, # pylint: disable=protected-access 3860 batch_size=self._batch_size, 3861 padded_shapes=[ 3862 ops.convert_to_tensor(s, dtype=dtypes.int64) 3863 for s in nest.flatten(self._padded_shapes) 3864 ], 3865 padding_values=nest.flatten(self._padding_values), 3866 output_shapes=structure.get_flat_tensor_shapes(self._structure)) 3867 else: 3868 variant_tensor = gen_dataset_ops.padded_batch_dataset_v2( 3869 input_dataset._variant_tensor, # pylint: disable=protected-access 3870 batch_size=self._batch_size, 3871 padded_shapes=[ 3872 ops.convert_to_tensor(s, dtype=dtypes.int64) 3873 for s in nest.flatten(self._padded_shapes) 3874 ], 3875 padding_values=nest.flatten(self._padding_values), 3876 drop_remainder=self._drop_remainder, 3877 output_shapes=structure.get_flat_tensor_shapes(self._structure)) 3878 super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor) 3879 3880 @property 3881 def element_spec(self): 3882 return self._structure 3883 3884 3885def _should_unpack_args(args): 3886 """Returns `True` if `args` should be `*args` when passed to a callable.""" 3887 return type(args) is tuple # pylint: disable=unidiomatic-typecheck 3888 3889 3890class MapDataset(UnaryDataset): 3891 """A `Dataset` that maps a function over elements in its input.""" 3892 3893 def __init__(self, 3894 input_dataset, 3895 map_func, 3896 use_inter_op_parallelism=True, 3897 preserve_cardinality=False, 3898 use_legacy_function=False): 3899 """See `Dataset.map()` for details.""" 3900 self._input_dataset = input_dataset 3901 self._use_inter_op_parallelism = use_inter_op_parallelism 3902 self._preserve_cardinality = preserve_cardinality 3903 self._map_func = StructuredFunctionWrapper( 3904 map_func, 3905 self._transformation_name(), 3906 dataset=input_dataset, 3907 use_legacy_function=use_legacy_function) 3908 variant_tensor = gen_dataset_ops.map_dataset( 3909 input_dataset._variant_tensor, # pylint: disable=protected-access 3910 self._map_func.function.captured_inputs, 3911 f=self._map_func.function, 3912 use_inter_op_parallelism=self._use_inter_op_parallelism, 3913 preserve_cardinality=self._preserve_cardinality, 3914 **self._flat_structure) 3915 super(MapDataset, self).__init__(input_dataset, variant_tensor) 3916 3917 def _functions(self): 3918 return [self._map_func] 3919 3920 @property 3921 def element_spec(self): 3922 return self._map_func.output_structure 3923 3924 def _transformation_name(self): 3925 return "Dataset.map()" 3926 3927 3928class ParallelMapDataset(UnaryDataset): 3929 """A `Dataset` that maps a function over elements in its input in parallel.""" 3930 3931 def __init__(self, 3932 input_dataset, 3933 map_func, 3934 num_parallel_calls, 3935 use_inter_op_parallelism=True, 3936 preserve_cardinality=False, 3937 use_legacy_function=False): 3938 """See `Dataset.map()` for details.""" 3939 self._input_dataset = input_dataset 3940 self._use_inter_op_parallelism = use_inter_op_parallelism 3941 self._map_func = StructuredFunctionWrapper( 3942 map_func, 3943 self._transformation_name(), 3944 dataset=input_dataset, 3945 use_legacy_function=use_legacy_function) 3946 self._num_parallel_calls = ops.convert_to_tensor( 3947 num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls") 3948 self._preserve_cardinality = preserve_cardinality 3949 variant_tensor = gen_dataset_ops.parallel_map_dataset( 3950 input_dataset._variant_tensor, # pylint: 
disable=protected-access 3951 self._map_func.function.captured_inputs, 3952 f=self._map_func.function, 3953 num_parallel_calls=self._num_parallel_calls, 3954 use_inter_op_parallelism=self._use_inter_op_parallelism, 3955 preserve_cardinality=self._preserve_cardinality, 3956 **self._flat_structure) 3957 super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor) 3958 3959 def _functions(self): 3960 return [self._map_func] 3961 3962 @property 3963 def element_spec(self): 3964 return self._map_func.output_structure 3965 3966 def _transformation_name(self): 3967 return "Dataset.map()" 3968 3969 3970class FlatMapDataset(UnaryDataset): 3971 """A `Dataset` that maps a function over its input and flattens the result.""" 3972 3973 def __init__(self, input_dataset, map_func): 3974 """See `Dataset.flat_map()` for details.""" 3975 self._input_dataset = input_dataset 3976 self._map_func = StructuredFunctionWrapper( 3977 map_func, self._transformation_name(), dataset=input_dataset) 3978 if not isinstance(self._map_func.output_structure, DatasetSpec): 3979 raise TypeError( 3980 "`map_func` must return a `Dataset` object. Got {}".format( 3981 type(self._map_func.output_structure))) 3982 self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access 3983 variant_tensor = gen_dataset_ops.flat_map_dataset( 3984 input_dataset._variant_tensor, # pylint: disable=protected-access 3985 self._map_func.function.captured_inputs, 3986 f=self._map_func.function, 3987 **self._flat_structure) 3988 super(FlatMapDataset, self).__init__(input_dataset, variant_tensor) 3989 3990 def _functions(self): 3991 return [self._map_func] 3992 3993 @property 3994 def element_spec(self): 3995 return self._structure 3996 3997 def _transformation_name(self): 3998 return "Dataset.flat_map()" 3999 4000 4001class InterleaveDataset(UnaryDataset): 4002 """A `Dataset` that interleaves the result of transformed inputs.""" 4003 4004 def __init__(self, input_dataset, map_func, cycle_length, block_length): 4005 """See `Dataset.interleave()` for details.""" 4006 self._input_dataset = input_dataset 4007 self._map_func = StructuredFunctionWrapper( 4008 map_func, self._transformation_name(), dataset=input_dataset) 4009 if not isinstance(self._map_func.output_structure, DatasetSpec): 4010 raise TypeError( 4011 "`map_func` must return a `Dataset` object. 
Got {}".format( 4012 type(self._map_func.output_structure))) 4013 self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access 4014 self._cycle_length = ops.convert_to_tensor( 4015 cycle_length, dtype=dtypes.int64, name="cycle_length") 4016 self._block_length = ops.convert_to_tensor( 4017 block_length, dtype=dtypes.int64, name="block_length") 4018 4019 variant_tensor = gen_dataset_ops.interleave_dataset( 4020 input_dataset._variant_tensor, # pylint: disable=protected-access 4021 self._map_func.function.captured_inputs, # pylint: disable=protected-access 4022 self._cycle_length, 4023 self._block_length, 4024 f=self._map_func.function, 4025 **self._flat_structure) 4026 super(InterleaveDataset, self).__init__(input_dataset, variant_tensor) 4027 4028 def _functions(self): 4029 return [self._map_func] 4030 4031 @property 4032 def element_spec(self): 4033 return self._structure 4034 4035 def _transformation_name(self): 4036 return "Dataset.interleave()" 4037 4038 4039class ParallelInterleaveDataset(UnaryDataset): 4040 """A `Dataset` that maps a function over its input and interleaves the result.""" 4041 4042 def __init__(self, 4043 input_dataset, 4044 map_func, 4045 cycle_length, 4046 block_length, 4047 num_parallel_calls, 4048 deterministic=None): 4049 """See `Dataset.interleave()` for details.""" 4050 self._input_dataset = input_dataset 4051 self._map_func = StructuredFunctionWrapper( 4052 map_func, self._transformation_name(), dataset=input_dataset) 4053 if not isinstance(self._map_func.output_structure, DatasetSpec): 4054 raise TypeError( 4055 "`map_func` must return a `Dataset` object. Got {}".format( 4056 type(self._map_func.output_structure))) 4057 self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access 4058 self._cycle_length = ops.convert_to_tensor( 4059 cycle_length, dtype=dtypes.int64, name="cycle_length") 4060 self._block_length = ops.convert_to_tensor( 4061 block_length, dtype=dtypes.int64, name="block_length") 4062 self._num_parallel_calls = ops.convert_to_tensor( 4063 num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") 4064 if deterministic is None: 4065 deterministic_string = "default" 4066 elif deterministic: 4067 deterministic_string = "true" 4068 else: 4069 deterministic_string = "false" 4070 4071 if deterministic is not None or compat.forward_compatible(2020, 2, 20): 4072 variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v3( 4073 input_dataset._variant_tensor, # pylint: disable=protected-access 4074 self._map_func.function.captured_inputs, # pylint: disable=protected-access 4075 self._cycle_length, 4076 self._block_length, 4077 self._num_parallel_calls, 4078 f=self._map_func.function, 4079 deterministic=deterministic_string, 4080 **self._flat_structure) 4081 else: 4082 variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v2( 4083 input_dataset._variant_tensor, # pylint: disable=protected-access 4084 self._map_func.function.captured_inputs, # pylint: disable=protected-access 4085 self._cycle_length, 4086 self._block_length, 4087 self._num_parallel_calls, 4088 f=self._map_func.function, 4089 **self._flat_structure) 4090 super(ParallelInterleaveDataset, self).__init__(input_dataset, 4091 variant_tensor) 4092 4093 def _functions(self): 4094 return [self._map_func] 4095 4096 @property 4097 def element_spec(self): 4098 return self._structure 4099 4100 def _transformation_name(self): 4101 return "Dataset.interleave()" 4102 4103 4104class 
FilterDataset(UnaryUnchangedStructureDataset): 4105 """A `Dataset` that filters its input according to a predicate function.""" 4106 4107 def __init__(self, input_dataset, predicate, use_legacy_function=False): 4108 """See `Dataset.filter()` for details.""" 4109 self._input_dataset = input_dataset 4110 wrapped_func = StructuredFunctionWrapper( 4111 predicate, 4112 self._transformation_name(), 4113 dataset=input_dataset, 4114 use_legacy_function=use_legacy_function) 4115 if not wrapped_func.output_structure.is_compatible_with( 4116 tensor_spec.TensorSpec([], dtypes.bool)): 4117 error_msg = ("`predicate` return type must be convertible to a scalar " 4118 "boolean tensor. Was {}.").format( 4119 wrapped_func.output_structure) 4120 raise ValueError(error_msg) 4121 self._predicate = wrapped_func 4122 variant_tensor = gen_dataset_ops.filter_dataset( 4123 input_dataset._variant_tensor, # pylint: disable=protected-access 4124 other_arguments=self._predicate.function.captured_inputs, 4125 predicate=self._predicate.function, 4126 **self._flat_structure) 4127 super(FilterDataset, self).__init__(input_dataset, variant_tensor) 4128 4129 def _functions(self): 4130 return [self._predicate] 4131 4132 def _transformation_name(self): 4133 return "Dataset.filter()" 4134 4135 4136class PrefetchDataset(UnaryUnchangedStructureDataset): 4137 """A `Dataset` that asynchronously prefetches its input.""" 4138 4139 def __init__(self, input_dataset, buffer_size, slack_period=None): 4140 """See `Dataset.prefetch()` for details. 4141 4142 Args: 4143 input_dataset: The input dataset. 4144 buffer_size: See `Dataset.prefetch()` for details. 4145 slack_period: (Optional.) An integer. If non-zero, determines the number 4146 of GetNext calls before injecting slack into the execution. This may 4147 reduce CPU contention at the start of a step. Note that a tensorflow 4148 user should not have to set this manually; enable this behavior 4149 automatically via `tf.data.Options.experimental_slack` instead. Defaults 4150 to None. 4151 """ 4152 self._input_dataset = input_dataset 4153 if buffer_size is None: 4154 buffer_size = -1 # This is the sentinel for auto-tuning. 
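    # NOTE: -1 is the same sentinel value exposed publicly as
    # `tf.data.experimental.AUTOTUNE` (the `AUTOTUNE` constant defined near the
    # top of this module), so constructing `PrefetchDataset` with
    # `buffer_size=None` behaves like prefetching with AUTOTUNE.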
4155 self._buffer_size = ops.convert_to_tensor( 4156 buffer_size, dtype=dtypes.int64, name="buffer_size") 4157 variant_tensor = gen_dataset_ops.prefetch_dataset( 4158 input_dataset._variant_tensor, # pylint: disable=protected-access 4159 buffer_size=self._buffer_size, 4160 slack_period=slack_period, 4161 **self._flat_structure) 4162 super(PrefetchDataset, self).__init__(input_dataset, variant_tensor) 4163 4164 4165class WindowDataset(UnaryDataset): 4166 """A dataset that creates window datasets from the input elements.""" 4167 4168 def __init__(self, input_dataset, size, shift, stride, drop_remainder): 4169 """See `window_dataset()` for more details.""" 4170 self._input_dataset = input_dataset 4171 self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size") 4172 self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift") 4173 self._stride = ops.convert_to_tensor( 4174 stride, dtype=dtypes.int64, name="stride") 4175 self._drop_remainder = ops.convert_to_tensor( 4176 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 4177 self._structure = nest.pack_sequence_as( 4178 get_legacy_output_classes(input_dataset), [ 4179 DatasetSpec( # pylint: disable=g-complex-comprehension 4180 structure.convert_legacy_structure( 4181 output_type, output_shape, output_class)) 4182 for output_class, output_shape, output_type in zip( 4183 nest.flatten(get_legacy_output_classes(input_dataset)), 4184 nest.flatten(get_legacy_output_shapes(input_dataset)), 4185 nest.flatten(get_legacy_output_types(input_dataset))) 4186 ]) 4187 variant_tensor = gen_dataset_ops.window_dataset( 4188 input_dataset._variant_tensor, # pylint: disable=protected-access 4189 self._size, 4190 self._shift, 4191 self._stride, 4192 self._drop_remainder, 4193 **self._flat_structure) 4194 super(WindowDataset, self).__init__(input_dataset, variant_tensor) 4195 4196 @property 4197 def element_spec(self): 4198 return self._structure 4199 4200 4201class _OptionsDataset(UnaryUnchangedStructureDataset): 4202 """An identity `Dataset` that stores options.""" 4203 4204 def __init__(self, input_dataset, options): 4205 self._input_dataset = input_dataset 4206 self._options = input_dataset.options() 4207 if self._options: 4208 self._options = self._options.merge(options) 4209 else: 4210 self._options = options 4211 variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access 4212 super(_OptionsDataset, self).__init__(input_dataset, variant_tensor) 4213 4214 def options(self): 4215 return self._options 4216 4217 4218class _ModelDataset(UnaryUnchangedStructureDataset): 4219 """A `Dataset` that acts as an identity, and models performance.""" 4220 4221 def __init__(self, input_dataset, algorithm, cpu_budget): 4222 self._input_dataset = input_dataset 4223 variant_tensor = gen_dataset_ops.model_dataset( 4224 input_dataset._variant_tensor, # pylint: disable=protected-access 4225 algorithm=algorithm.value, 4226 cpu_budget=cpu_budget, 4227 **self._flat_structure) 4228 super(_ModelDataset, self).__init__(input_dataset, variant_tensor) 4229 4230 4231class _OptimizeDataset(UnaryUnchangedStructureDataset): 4232 """A `Dataset` that acts as an identity, and applies optimizations.""" 4233 4234 def __init__(self, input_dataset, optimizations, optimization_configs=None): 4235 self._input_dataset = input_dataset 4236 if optimizations is None: 4237 optimizations = [] 4238 if optimization_configs is None: 4239 optimization_configs = [] 4240 self._optimizations = ops.convert_to_tensor( 4241 optimizations, 
dtype=dtypes.string, name="optimizations") 4242 variant_tensor = gen_dataset_ops.optimize_dataset( 4243 input_dataset._variant_tensor, # pylint: disable=protected-access 4244 self._optimizations, 4245 optimization_configs=optimization_configs, 4246 **self._flat_structure) 4247 super(_OptimizeDataset, self).__init__(input_dataset, variant_tensor) 4248 4249 4250class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset): 4251 """A `Dataset` that acts as an identity, and sets a stats aggregator.""" 4252 4253 def __init__(self, input_dataset, aggregator, prefix, counter_prefix): 4254 self._input_dataset = input_dataset 4255 self._stats_aggregator = aggregator 4256 self._prefix = prefix 4257 self._counter_prefix = counter_prefix 4258 variant_tensor = ged_ops.set_stats_aggregator_dataset( 4259 input_dataset._variant_tensor, # pylint: disable=protected-access 4260 self._stats_aggregator._resource, # pylint: disable=protected-access 4261 self._prefix, 4262 self._counter_prefix, 4263 **self._flat_structure) 4264 super(_SetStatsAggregatorDataset, self).__init__(input_dataset, 4265 variant_tensor) 4266 4267 4268class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset): 4269 """A `Dataset` that acts as an identity, overriding intra-op parallelism.""" 4270 4271 def __init__(self, input_dataset, max_intra_op_parallelism): 4272 self._input_dataset = input_dataset 4273 self._max_intra_op_parallelism = ops.convert_to_tensor( 4274 max_intra_op_parallelism, 4275 dtype=dtypes.int64, 4276 name="max_intra_op_parallelism") 4277 variant_tensor = ged_ops.max_intra_op_parallelism_dataset( 4278 input_dataset._variant_tensor, # pylint: disable=protected-access 4279 self._max_intra_op_parallelism, 4280 **self._flat_structure) 4281 super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset, 4282 variant_tensor) 4283 4284 4285class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset): 4286 """A `Dataset` that acts as an identity, setting a private threadpool.""" 4287 4288 def __init__(self, input_dataset, num_threads): 4289 self._input_dataset = input_dataset 4290 self._num_threads = ops.convert_to_tensor( 4291 num_threads, dtype=dtypes.int64, name="num_threads") 4292 variant_tensor = ged_ops.private_thread_pool_dataset( 4293 input_dataset._variant_tensor, # pylint: disable=protected-access 4294 self._num_threads, 4295 **self._flat_structure) 4296 super(_PrivateThreadPoolDataset, self).__init__(input_dataset, 4297 variant_tensor) 4298 4299 4300def normalize_to_dense(dataset): 4301 """Normalizes non-tensor components in a dataset to dense representations. 4302 4303 This is necessary for dataset transformations that slice along the batch 4304 dimension and are oblivious to non-tensors, e.g. `unbatch`, `rebatch`. 4305 4306 Args: 4307 dataset: Dataset to normalize. 4308 4309 Returns: 4310 A dataset whose sparse and ragged tensors have been normalized to their 4311 dense representations. 4312 """ 4313 4314 # NOTE(mrry): This leads to a somewhat inefficient re-encoding step for all 4315 # non-tensor components. 4316 # 4317 # TODO(mrry): Consider optimizing this if it turns out to be a bottleneck. 
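  # `Dataset.map()` unpacks tuple elements into positional arguments of the
  # mapped function, so the normalization function below needs a `*args`
  # signature for tuple-structured elements and a single-argument signature
  # for everything else.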
4318 if _should_unpack_args(dataset.element_spec): 4319 def normalize(*args): 4320 return structure.to_batched_tensor_list(dataset.element_spec, tuple(args)) 4321 else: 4322 def normalize(arg): 4323 return structure.to_batched_tensor_list(dataset.element_spec, arg) 4324 4325 normalized_dataset = dataset.map(normalize) 4326 4327 # NOTE(mrry): Our `map()` has lost information about the structure of 4328 # non-tensor components, so re-apply the structure of the original dataset. 4329 return _RestructuredDataset(normalized_dataset, dataset.element_spec) 4330 4331 4332class _RestructuredDataset(UnaryDataset): 4333 """An internal helper for changing the structure and shape of a dataset.""" 4334 4335 def __init__(self, dataset, structure): 4336 self._input_dataset = dataset 4337 self._structure = structure 4338 4339 variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access 4340 super(_RestructuredDataset, self).__init__(dataset, variant_tensor) 4341 4342 @property 4343 def element_spec(self): 4344 return self._structure 4345 4346 4347class _UnbatchDataset(UnaryDataset): 4348 """A dataset that splits the elements of its input into multiple elements.""" 4349 4350 def __init__(self, input_dataset): 4351 """See `unbatch()` for more details.""" 4352 flat_shapes = input_dataset._flat_shapes # pylint: disable=protected-access 4353 if any(s.ndims == 0 for s in flat_shapes): 4354 raise ValueError("Cannot unbatch an input with scalar components.") 4355 known_batch_dim = tensor_shape.Dimension(None) 4356 for s in flat_shapes: 4357 try: 4358 known_batch_dim = known_batch_dim.merge_with(s[0]) 4359 except ValueError: 4360 raise ValueError("Cannot unbatch an input whose components have " 4361 "different batch sizes.") 4362 self._input_dataset = input_dataset 4363 self._structure = nest.map_structure( 4364 lambda component_spec: component_spec._unbatch(), # pylint: disable=protected-access 4365 get_structure(input_dataset)) 4366 variant_tensor = ged_ops.unbatch_dataset( 4367 self._input_dataset._variant_tensor, # pylint: disable=protected-access 4368 **self._flat_structure) 4369 super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor) 4370 4371 @property 4372 def element_spec(self): 4373 return self._structure 4374 4375 4376def _collect_resource_inputs(op): 4377 """Collects resource inputs for the given ops (and its variant inputs).""" 4378 4379 def _process(op_queue, seen_ops): 4380 """Processes the next element of the op queue.""" 4381 4382 result = [] 4383 op = op_queue.pop() 4384 if op in seen_ops: 4385 return result 4386 seen_ops.add(op) 4387 for t in op.inputs: 4388 if t.dtype == dtypes.variant: 4389 # Conservatively assume that any variant inputs are datasets. 
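        # Following variant edges transitively also picks up resources that
        # are captured inputs of upstream dataset ops (for example, a lookup
        # table captured by a `map` function).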
4390 op_queue.append(t.op) 4391 elif t.dtype == dtypes.resource: 4392 result.append(t) 4393 return result 4394 4395 op_queue = [op] 4396 seen_ops = set() 4397 resource_inputs = [] 4398 while op_queue: 4399 resource_inputs.extend(_process(op_queue, seen_ops)) 4400 4401 return resource_inputs 4402 4403 4404@auto_control_deps.register_acd_resource_resolver 4405def _resource_resolver(op, resource_inputs): 4406 """Updates resource inputs for tf.data ops with indirect dependencies.""" 4407 4408 updated = False 4409 if op.type in [ 4410 "DatasetToSingleElement", "DatasetToTFRecord", "ReduceDataset" 4411 ]: 4412 indirect_resource_inputs = _collect_resource_inputs(op) 4413 for inp in indirect_resource_inputs: 4414 if inp not in resource_inputs: 4415 updated = True 4416 resource_inputs.add(inp) 4417 4418 if op.type in [ 4419 "IteratorGetNext", "IteratorGetNextSync", "IteratorGetNextAsOptional" 4420 ]: 4421 iterator_resource = op.inputs[0] 4422 make_iterator_ops = [ 4423 op for op in iterator_resource.consumers() if op.type == "MakeIterator" 4424 ] 4425 4426 if len(make_iterator_ops) == 1: 4427 indirect_resource_inputs = _collect_resource_inputs(make_iterator_ops[0]) 4428 for inp in indirect_resource_inputs: 4429 if inp not in resource_inputs: 4430 updated = True 4431 resource_inputs.add(inp) 4432 4433 return updated 4434
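# A minimal usage sketch (not executed as part of this module) of the
# `to_variant`/`from_variant` helpers defined above, assuming an eager TF 2.x
# context:
#
#   ds = DatasetV2.range(3)
#   variant = to_variant(ds)          # scalar `tf.variant` tensor
#   restored = from_variant(variant, ds.element_spec)
#
# `restored` is a `_VariantDataset` that yields the same elements as `ds`;
# only the Python wrapper object differs.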