# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import functools
import sys
import threading
import warnings
import weakref

import numpy as np
import six
from six.moves import queue as Queue  # pylint: disable=redefined-builtin

from tensorflow.core.framework import dataset_options_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import options as options_lib
from tensorflow.python.data.util import random_seed
from tensorflow.python.data.util import structure
from tensorflow.python.data.util import traverse
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import auto_control_deps_utils as acd_utils
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed as core_random_seed
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.training.tracking import base as tracking_base
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest as tf_nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export

# Loaded lazily due to a circular dependency (roughly
# tf.function->wrap_function->dataset->autograph->tf.function).
# TODO(b/133251390): Use a regular import.
wrap_function = lazy_loader.LazyLoader(
    "wrap_function", globals(),
    "tensorflow.python.eager.wrap_function")
# TODO(mdan): Create a public API for this.
autograph_ctx = lazy_loader.LazyLoader(
    "autograph_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")
autograph = lazy_loader.LazyLoader(
    "autograph", globals(),
    "tensorflow.python.autograph.impl.api")

ops.NotDifferentiable("ReduceDataset")

# A constant that can be used to enable auto-tuning.
AUTOTUNE = -1
tf_export("data.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
# TODO(b/168128531): Deprecate and remove this symbol.
tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")

# Constants representing infinite and unknown cardinalities.
INFINITE = -1
UNKNOWN = -2
tf_export("data.INFINITE_CARDINALITY").export_constant(__name__, "INFINITE")
tf_export("data.UNKNOWN_CARDINALITY").export_constant(__name__, "UNKNOWN")


@tf_export("data.Dataset", v1=[])
@six.add_metaclass(abc.ABCMeta)
class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
                composite_tensor.CompositeTensor):
  """Represents a potentially large set of elements.

  The `tf.data.Dataset` API supports writing descriptive and efficient input
  pipelines. `Dataset` usage follows a common pattern:

  1. Create a source dataset from your input data.
  2. Apply dataset transformations to preprocess the data.
  3. Iterate over the dataset and process the elements.

  Iteration happens in a streaming fashion, so the full dataset does not need
  to fit into memory.

  Source Datasets:

  The simplest way to create a dataset is to create it from a python `list`:

  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> for element in dataset:
  ...   print(element)
  tf.Tensor(1, shape=(), dtype=int32)
  tf.Tensor(2, shape=(), dtype=int32)
  tf.Tensor(3, shape=(), dtype=int32)

  To process lines from files, use `tf.data.TextLineDataset`:

  >>> dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])

  To process records written in the `TFRecord` format, use `TFRecordDataset`:

  >>> dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])

  To create a dataset of all files matching a pattern, use
  `tf.data.Dataset.list_files`:

  >>> dataset = tf.data.Dataset.list_files("/path/*.txt")  # doctest: +SKIP

  See `tf.data.FixedLengthRecordDataset` and `tf.data.Dataset.from_generator`
  for more ways to create datasets.

  Transformations:

  Once you have a dataset, you can apply transformations to prepare the data
  for your model:

  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> dataset = dataset.map(lambda x: x*2)
  >>> list(dataset.as_numpy_iterator())
  [2, 4, 6]

  Common Terms:

  **Element**: A single output from calling `next()` on a dataset iterator.
  Elements may be nested structures containing multiple components. For
  example, the element `(1, (3, "apple"))` has one tuple nested in another
  tuple. The components are `1`, `3`, and `"apple"`.

  **Component**: The leaf in the nested structure of an element.

  Supported types:

  Elements can be nested structures of tuples, named tuples, and dictionaries.
  Note that Python lists are *not* treated as nested structures of components.
  Instead, lists are converted to tensors and treated as components. For
  example, the element `(1, [1, 2, 3])` has only two components; the tensor `1`
  and the tensor `[1, 2, 3]`. Element components can be of any type
  representable by `tf.TypeSpec`, including `tf.Tensor`, `tf.data.Dataset`,
  `tf.sparse.SparseTensor`, `tf.RaggedTensor`, and `tf.TensorArray`.

  >>> a = 1 # Integer element
  >>> b = 2.0 # Float element
  >>> c = (1, 2) # Tuple element with 2 components
  >>> d = {"a": (2, 2), "b": 3} # Dict element with 3 components
  >>> Point = collections.namedtuple("Point", ["x", "y"]) # doctest: +SKIP
  >>> e = Point(1, 2) # Named tuple # doctest: +SKIP
  >>> f = tf.data.Dataset.range(10) # Dataset element

  """

  def __init__(self, variant_tensor):
    """Creates a DatasetV2 object.

    This is a difference between DatasetV1 and DatasetV2. DatasetV1 does not
    take anything in its constructor whereas in the DatasetV2, we expect
    subclasses to create a variant_tensor and pass it in to the super() call.

    Args:
      variant_tensor: A DT_VARIANT tensor that represents the dataset.
    """
    self._variant_tensor_attr = variant_tensor
    weak_self = weakref.proxy(self)
    self._variant_tracker = self._track_trackable(
        _VariantTracker(
            self._variant_tensor,
            # _trace_variant_creation only works when executing eagerly, so we
            # don't want to run it immediately. We also want the _VariantTracker
            # to have a weak reference to the Dataset to avoid creating
            # reference cycles and making work for the garbage collector.
            lambda: weak_self._trace_variant_creation()()),  # pylint: disable=unnecessary-lambda,protected-access
        name="_variant_tracker")
    self._graph_attr = ops.get_default_graph()

    # Initialize the options for this dataset and its inputs.
    self._options_attr = Options()
    for input_dataset in self._inputs():
      input_options = input_dataset.options()
      if input_options is not None:
        self._options_attr = self._options_attr.merge(input_options)

  @property
  def _variant_tensor(self):
    return self._variant_tensor_attr

  @_variant_tensor.setter
  def _variant_tensor(self, _):
    raise ValueError("The _variant_tensor property is read-only")

  @deprecation.deprecated_args(None, "Use external_state_policy instead",
                               "allow_stateful")
  def _as_serialized_graph(
      self,
      allow_stateful=None,
      strip_device_assignment=None,
      external_state_policy=distribute_options.ExternalStatePolicy.WARN):
    """Produces serialized graph representation of the dataset.

    Args:
      allow_stateful: If true, we allow stateful ops to be present in the graph
        def. In that case, the state in these ops would be thrown away.
      strip_device_assignment: If true, non-local (i.e. job and task) device
        assignment is stripped from ops in the serialized graph.
      external_state_policy: The ExternalStatePolicy enum that determines how
        we handle input pipelines that depend on external state. By default,
        it's set to WARN.

    Returns:
      A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
      serialized graph.
    """
    if external_state_policy:
      policy = external_state_policy.value
      return gen_dataset_ops.dataset_to_graph_v2(
          self._variant_tensor,
          external_state_policy=policy,
          strip_device_assignment=strip_device_assignment)
    if strip_device_assignment:
      return gen_dataset_ops.dataset_to_graph(
          self._variant_tensor,
          allow_stateful=allow_stateful,
          strip_device_assignment=strip_device_assignment)
    return gen_dataset_ops.dataset_to_graph(
        self._variant_tensor, allow_stateful=allow_stateful)

  def _trace_variant_creation(self):
    """Traces a function which outputs a variant `tf.Tensor` for this dataset.

    Note that creating this function involves evaluating an op, and is
    currently only supported when executing eagerly.

    Returns:
      A zero-argument `ConcreteFunction` which outputs a variant `tf.Tensor`.
    """
    variant = self._variant_tensor
    if not isinstance(variant, ops.EagerTensor):
      raise NotImplementedError(
          "Can only export Datasets which were created executing eagerly. "
          "Please file a feature request if this is important to you.")
    with context.eager_mode(), ops.device("CPU"):
      # pylint: disable=protected-access
      graph_def = graph_pb2.GraphDef().FromString(
          self._as_serialized_graph(external_state_policy=distribute_options
                                    .ExternalStatePolicy.FAIL).numpy())
    output_node_name = None
    for node in graph_def.node:
      if node.op == "_Retval":
        if output_node_name is not None:
          raise AssertionError(
              "Found multiple return values from the dataset's graph, expected "
              "only one.")
        output_node_name, = node.input
    if output_node_name is None:
      raise AssertionError("Could not find the dataset's output node.")
    # Add functions used in this Dataset to the function's graph, since they
    # need to follow it around (and for example be added to a SavedModel which
    # references the dataset).
    variant_function = wrap_function.function_from_graph_def(
        graph_def, inputs=[], outputs=output_node_name + ":0")
    for used_function in self._functions():
      used_function.function.add_to_graph(variant_function.graph)
    return variant_function

  @abc.abstractmethod
  def _inputs(self):
    """Returns a list of the input datasets of the dataset."""

    raise NotImplementedError("Dataset._inputs")

  @property
  def _graph(self):
    return self._graph_attr

  @_graph.setter
  def _graph(self, _):
    raise ValueError("The _graph property is read-only")

  def _has_captured_ref(self):
    """Whether this dataset uses a function that captures ref variables.

    Returns:
      A boolean, which if true indicates that the dataset or one of its inputs
      uses a function that captures ref variables.
    """
    if context.executing_eagerly():
      # RefVariables are not supported in eager mode
      return False

    def is_tensor_or_parent_ref(tensor):
      if tensor.dtype._is_ref_dtype:  # pylint: disable=protected-access
        return True
      # If the captured tensor is an eager tensor, we cannot trace its inputs.
      if isinstance(tensor, ops._EagerTensorBase):  # pylint: disable=protected-access
        return False
      return any(is_tensor_or_parent_ref(x) for x in tensor.op.inputs)

    for fn in self._functions():
      if any(is_tensor_or_parent_ref(t) for t in fn.function.captured_inputs):
        return True

    return any(
        [input_dataset._has_captured_ref() for input_dataset in self._inputs()])  # pylint: disable=protected-access

  # TODO(jsimsa): Change this to be the transitive closure of functions used
  # by this dataset and its inputs.
  def _functions(self):
    """Returns a list of functions associated with this dataset.

    Returns:
      A list of `StructuredFunctionWrapper` objects.
    """
    return []

  def options(self):
    """Returns the options for this dataset and its inputs.

    Returns:
      A `tf.data.Options` object representing the dataset options.
    """
    return self._options_attr

  def _apply_options(self):
    """Apply options, such as optimization configuration, to the dataset."""

    dataset = self
    options = self.options()

    # (1) Apply threading options
    if options.experimental_threading is not None:
      t_options = options.experimental_threading
      if t_options.max_intra_op_parallelism is not None:
        dataset = _MaxIntraOpParallelismDataset(
            dataset, t_options.max_intra_op_parallelism)
      if t_options.private_threadpool_size is not None:
        dataset = _PrivateThreadPoolDataset(dataset,
                                            t_options.private_threadpool_size)

    # (2) Apply autotune options
    autotune, algorithm, cpu_budget, ram_budget = options._autotune_settings()  # pylint: disable=protected-access
    if autotune:
      dataset = _ModelDataset(dataset, algorithm, cpu_budget, ram_budget)

    # (3) Apply graph rewrite options
    # pylint: disable=protected-access
    graph_rewrites = options._graph_rewrites()
    graph_rewrite_configs = options._graph_rewrite_configs(autotune)
    # pylint: enable=protected-access
    if self._has_captured_ref():
      if graph_rewrites.enabled or graph_rewrites.default:
        warnings.warn(
            "tf.data graph rewrites are not compatible with tf.Variable. "
            "The following rewrites will be disabled: %s. To enable "
            "rewrites, use resource variables instead by calling "
            "`tf.enable_resource_variables()` at the start of the program." %
            ", ".join(graph_rewrites.enabled + graph_rewrites.default))
    elif (graph_rewrites.enabled or graph_rewrites.default or
          (options.experimental_optimization.apply_default_optimizations  # pylint: disable=g-bool-id-comparison
           is not False)):
      dataset = _OptimizeDataset(dataset, graph_rewrites.enabled,
                                 graph_rewrites.disabled,
                                 graph_rewrites.default, graph_rewrite_configs)

    # (4) Apply stats aggregator options
    if options.experimental_stats and options.experimental_stats.aggregator:  # pylint: disable=line-too-long
      dataset = _SetStatsAggregatorDataset(  # pylint: disable=protected-access
          dataset, options.experimental_stats.aggregator,
          options.experimental_stats.prefix,
          options.experimental_stats.counter_prefix)
    return dataset

  def __iter__(self):
    """Creates an iterator for elements of this dataset.

    The returned iterator implements the Python Iterator protocol.

    Returns:
      A `tf.data.Iterator` for the elements of this dataset.

    Raises:
      RuntimeError: If not inside of tf.function and not executing eagerly.
420 """ 421 if context.executing_eagerly() or ops.inside_function(): 422 with ops.colocate_with(self._variant_tensor): 423 return iterator_ops.OwnedIterator(self) 424 else: 425 raise RuntimeError("__iter__() is only supported inside of tf.function " 426 "or when eager execution is enabled.") 427 428 def __bool__(self): 429 return True # Required as __len__ is defined 430 431 __nonzero__ = __bool__ # Python 2 backward compatibility 432 433 def __len__(self): 434 """Returns the length of the dataset if it is known and finite. 435 436 This method requires that you are running in eager mode, and that the 437 length of the dataset is known and non-infinite. When the length may be 438 unknown or infinite, or if you are running in graph mode, use 439 `tf.data.Dataset.cardinality` instead. 440 441 Returns: 442 An integer representing the length of the dataset. 443 444 Raises: 445 RuntimeError: If the dataset length is unknown or infinite, or if eager 446 execution is not enabled. 447 """ 448 if not context.executing_eagerly(): 449 raise TypeError("__len__() is not supported while tracing functions. " 450 "Use `tf.data.Dataset.cardinality` instead.") 451 length = self.cardinality() 452 if length.numpy() == INFINITE: 453 raise TypeError("dataset length is infinite.") 454 if length.numpy() == UNKNOWN: 455 raise TypeError("dataset length is unknown.") 456 return length 457 458 @abc.abstractproperty 459 def element_spec(self): 460 """The type specification of an element of this dataset. 461 462 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 463 >>> dataset.element_spec 464 TensorSpec(shape=(), dtype=tf.int32, name=None) 465 466 Returns: 467 A nested structure of `tf.TypeSpec` objects matching the structure of an 468 element of this dataset and specifying the type of individual components. 469 """ 470 raise NotImplementedError("Dataset.element_spec") 471 472 def __repr__(self): 473 output_shapes = nest.map_structure(str, get_legacy_output_shapes(self)) 474 output_shapes = str(output_shapes).replace("'", "") 475 output_types = nest.map_structure(repr, get_legacy_output_types(self)) 476 output_types = str(output_types).replace("'", "") 477 return ("<%s shapes: %s, types: %s>" % (type(self).__name__, output_shapes, 478 output_types)) 479 480 def as_numpy_iterator(self): 481 """Returns an iterator which converts all elements of the dataset to numpy. 482 483 Use `as_numpy_iterator` to inspect the content of your dataset. To see 484 element shapes and types, print dataset elements directly instead of using 485 `as_numpy_iterator`. 486 487 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 488 >>> for element in dataset: 489 ... print(element) 490 tf.Tensor(1, shape=(), dtype=int32) 491 tf.Tensor(2, shape=(), dtype=int32) 492 tf.Tensor(3, shape=(), dtype=int32) 493 494 This method requires that you are running in eager mode and the dataset's 495 element_spec contains only `TensorSpec` components. 496 497 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 498 >>> for element in dataset.as_numpy_iterator(): 499 ... print(element) 500 1 501 2 502 3 503 504 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 505 >>> print(list(dataset.as_numpy_iterator())) 506 [1, 2, 3] 507 508 `as_numpy_iterator()` will preserve the nested structure of dataset 509 elements. 510 511 >>> dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]), 512 ... 'b': [5, 6]}) 513 >>> list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5}, 514 ... 
  def as_numpy_iterator(self):
    """Returns an iterator which converts all elements of the dataset to numpy.

    Use `as_numpy_iterator` to inspect the content of your dataset. To see
    element shapes and types, print dataset elements directly instead of using
    `as_numpy_iterator`.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> for element in dataset:
    ...   print(element)
    tf.Tensor(1, shape=(), dtype=int32)
    tf.Tensor(2, shape=(), dtype=int32)
    tf.Tensor(3, shape=(), dtype=int32)

    This method requires that you are running in eager mode and the dataset's
    element_spec contains only `TensorSpec` components.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> for element in dataset.as_numpy_iterator():
    ...   print(element)
    1
    2
    3

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> print(list(dataset.as_numpy_iterator()))
    [1, 2, 3]

    `as_numpy_iterator()` will preserve the nested structure of dataset
    elements.

    >>> dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]),
    ...                                               'b': [5, 6]})
    >>> list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
    ...                                       {'a': (2, 4), 'b': 6}]
    True

    Returns:
      An iterable over the elements of the dataset, with their tensors converted
      to numpy arrays.

    Raises:
      TypeError: if an element contains a non-`Tensor` value.
      RuntimeError: if eager execution is not enabled.
    """
    if not context.executing_eagerly():
      raise RuntimeError("as_numpy_iterator() is not supported while tracing "
                         "functions")
    for component_spec in nest.flatten(self.element_spec):
      if not isinstance(
          component_spec,
          (tensor_spec.TensorSpec, ragged_tensor.RaggedTensorSpec)):
        raise TypeError(
            "Dataset.as_numpy_iterator() does not support datasets containing "
            + str(component_spec.value_type))

    return _NumpyIterator(self)

  @property
  def _flat_shapes(self):
    """Returns a list of `tf.TensorShape`s for the element tensor representation.

    Returns:
      A list of `tf.TensorShape`s for the element tensor representation.
    """
    return structure.get_flat_tensor_shapes(self.element_spec)

  @property
  def _flat_types(self):
    """Returns a list of `tf.DType`s for the element tensor representation.

    Returns:
      A list of `tf.DType`s for the element tensor representation.
    """
    return structure.get_flat_tensor_types(self.element_spec)

  @property
  def _flat_structure(self):
    """Helper for setting `output_shapes` and `output_types` attrs of an op.

    Most dataset op constructors expect `output_shapes` and `output_types`
    arguments that represent the flattened structure of an element. This helper
    function generates these attrs as a keyword argument dictionary, allowing
    `Dataset._variant_tensor` implementations to pass `**self._flat_structure`
    to the op constructor.

    Returns:
      A dictionary of keyword arguments that can be passed to a dataset op
      constructor.
    """
    return {
        "output_shapes": self._flat_shapes,
        "output_types": self._flat_types,
    }

  @property
  def _type_spec(self):
    return DatasetSpec(self.element_spec)

  @staticmethod
  def from_tensors(tensors):
    """Creates a `Dataset` with a single element, comprising the given tensors.

    `from_tensors` produces a dataset containing only a single element. To
    slice the input tensor into multiple elements, use `from_tensor_slices`
    instead.

    >>> dataset = tf.data.Dataset.from_tensors([1, 2, 3])
    >>> list(dataset.as_numpy_iterator())
    [array([1, 2, 3], dtype=int32)]
    >>> dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A'))
    >>> list(dataset.as_numpy_iterator())
    [(array([1, 2, 3], dtype=int32), b'A')]

    >>> # You can use `from_tensors` to produce a dataset which repeats
    >>> # the same example many times.
    >>> example = tf.constant([1,2,3])
    >>> dataset = tf.data.Dataset.from_tensors(example).repeat(2)
    >>> list(dataset.as_numpy_iterator())
    [array([1, 2, 3], dtype=int32), array([1, 2, 3], dtype=int32)]

    Note that if `tensors` contains a NumPy array, and eager execution is not
    enabled, the values will be embedded in the graph as one or more
    `tf.constant` operations. For large datasets (> 1 GB), this can waste
    memory and run into byte limits of graph serialization. If `tensors`
    contains one or more large NumPy arrays, consider the alternative described
    in [this
    guide](https://tensorflow.org/guide/data#consuming_numpy_arrays).

    Args:
      tensors: A dataset element.

    Returns:
      Dataset: A `Dataset`.
613 """ 614 return TensorDataset(tensors) 615 616 @staticmethod 617 def from_tensor_slices(tensors): 618 """Creates a `Dataset` whose elements are slices of the given tensors. 619 620 The given tensors are sliced along their first dimension. This operation 621 preserves the structure of the input tensors, removing the first dimension 622 of each tensor and using it as the dataset dimension. All input tensors 623 must have the same size in their first dimensions. 624 625 >>> # Slicing a 1D tensor produces scalar tensor elements. 626 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 627 >>> list(dataset.as_numpy_iterator()) 628 [1, 2, 3] 629 630 >>> # Slicing a 2D tensor produces 1D tensor elements. 631 >>> dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]]) 632 >>> list(dataset.as_numpy_iterator()) 633 [array([1, 2], dtype=int32), array([3, 4], dtype=int32)] 634 635 >>> # Slicing a tuple of 1D tensors produces tuple elements containing 636 >>> # scalar tensors. 637 >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6])) 638 >>> list(dataset.as_numpy_iterator()) 639 [(1, 3, 5), (2, 4, 6)] 640 641 >>> # Dictionary structure is also preserved. 642 >>> dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]}) 643 >>> list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3}, 644 ... {'a': 2, 'b': 4}] 645 True 646 647 >>> # Two tensors can be combined into one Dataset object. 648 >>> features = tf.constant([[1, 3], [2, 1], [3, 3]]) # ==> 3x2 tensor 649 >>> labels = tf.constant(['A', 'B', 'A']) # ==> 3x1 tensor 650 >>> dataset = Dataset.from_tensor_slices((features, labels)) 651 >>> # Both the features and the labels tensors can be converted 652 >>> # to a Dataset object separately and combined after. 653 >>> features_dataset = Dataset.from_tensor_slices(features) 654 >>> labels_dataset = Dataset.from_tensor_slices(labels) 655 >>> dataset = Dataset.zip((features_dataset, labels_dataset)) 656 >>> # A batched feature and label set can be converted to a Dataset 657 >>> # in similar fashion. 658 >>> batched_features = tf.constant([[[1, 3], [2, 3]], 659 ... [[2, 1], [1, 2]], 660 ... [[3, 3], [3, 2]]], shape=(3, 2, 2)) 661 >>> batched_labels = tf.constant([['A', 'A'], 662 ... ['B', 'B'], 663 ... ['A', 'B']], shape=(3, 2, 1)) 664 >>> dataset = Dataset.from_tensor_slices((batched_features, batched_labels)) 665 >>> for element in dataset.as_numpy_iterator(): 666 ... print(element) 667 (array([[1, 3], 668 [2, 3]], dtype=int32), array([[b'A'], 669 [b'A']], dtype=object)) 670 (array([[2, 1], 671 [1, 2]], dtype=int32), array([[b'B'], 672 [b'B']], dtype=object)) 673 (array([[3, 3], 674 [3, 2]], dtype=int32), array([[b'A'], 675 [b'B']], dtype=object)) 676 677 Note that if `tensors` contains a NumPy array, and eager execution is not 678 enabled, the values will be embedded in the graph as one or more 679 `tf.constant` operations. For large datasets (> 1 GB), this can waste 680 memory and run into byte limits of graph serialization. If `tensors` 681 contains one or more large NumPy arrays, consider the alternative described 682 in [this guide]( 683 https://tensorflow.org/guide/data#consuming_numpy_arrays). 684 685 Args: 686 tensors: A dataset element, with each component having the same size in 687 the first dimension. 688 689 Returns: 690 Dataset: A `Dataset`. 691 """ 692 return TensorSliceDataset(tensors) 693 694 class _GeneratorState(object): 695 """Stores outstanding iterators created from a Python generator. 

    This class keeps track of potentially multiple iterators that may have
    been created from a generator, e.g. in the case that the dataset is
    repeated, or nested within a parallel computation.
    """

    def __init__(self, generator):
      self._generator = generator
      self._lock = threading.Lock()
      self._next_id = 0  # GUARDED_BY(self._lock)
      self._args = {}
      self._iterators = {}

    def get_next_id(self, *args):
      with self._lock:
        ret = self._next_id
        self._next_id += 1
      self._args[ret] = args
      # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
      # casting in `py_func()` will create an array of `np.int32` on Windows,
      # leading to a runtime error.
      return np.array(ret, dtype=np.int64)

    def get_iterator(self, iterator_id):
      try:
        return self._iterators[iterator_id]
      except KeyError:
        iterator = iter(self._generator(*self._args.pop(iterator_id)))
        self._iterators[iterator_id] = iterator
        return iterator

    def iterator_completed(self, iterator_id):
      del self._iterators[iterator_id]
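
  # Illustrative sketch (commented out, not executed): the lifecycle that
  # `from_generator` drives through `_GeneratorState`. The names below are
  # local assumptions made up for illustration only.
  #
  #   state = DatasetV2._GeneratorState(my_generator_fn)
  #   iterator_id = state.get_next_id(*args)          # one id per traversal
  #   value = next(state.get_iterator(iterator_id))   # lazily creates the iterator
  #   state.iterator_completed(iterator_id)           # drops host-side state when done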

  @staticmethod
  @deprecation.deprecated_args(None, "Use output_signature instead",
                               "output_types", "output_shapes")
  def from_generator(generator,
                     output_types=None,
                     output_shapes=None,
                     args=None,
                     output_signature=None):
    """Creates a `Dataset` whose elements are generated by `generator`.

    The `generator` argument must be a callable object that returns
    an object that supports the `iter()` protocol (e.g. a generator function).

    The elements generated by `generator` must be compatible with either the
    given `output_signature` argument or with the given `output_types` and
    (optionally) `output_shapes` arguments, whichever was specified.

    The recommended way to call `from_generator` is to use the
    `output_signature` argument. In this case the output will be assumed to
    consist of objects with the classes, shapes and types defined by
    `tf.TypeSpec` objects from `output_signature` argument:

    >>> def gen():
    ...   ragged_tensor = tf.ragged.constant([[1, 2], [3]])
    ...   yield 42, ragged_tensor
    >>>
    >>> dataset = tf.data.Dataset.from_generator(
    ...     gen,
    ...     output_signature=(
    ...         tf.TensorSpec(shape=(), dtype=tf.int32),
    ...         tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))
    >>>
    >>> list(dataset.take(1))
    [(<tf.Tensor: shape=(), dtype=int32, numpy=42>,
    <tf.RaggedTensor [[1, 2], [3]]>)]

    There is also a deprecated way to call `from_generator`, either with the
    `output_types` argument alone or together with the `output_shapes`
    argument. In this case the output of the function will be assumed to
    consist of `tf.Tensor` objects with the types defined by `output_types`
    and with the shapes which are either unknown or defined by
    `output_shapes`.

    Note: The current implementation of `Dataset.from_generator()` uses
    `tf.numpy_function` and inherits the same constraints. In particular, it
    requires the dataset and iterator related operations to be placed
    on a device in the same process as the Python program that called
    `Dataset.from_generator()`. The body of `generator` will not be
    serialized in a `GraphDef`, and you should not use this method if you
    need to serialize your model and restore it in a different environment.

    Note: If `generator` depends on mutable global variables or other external
    state, be aware that the runtime may invoke `generator` multiple times
    (in order to support repeating the `Dataset`) and at any time
    between the call to `Dataset.from_generator()` and the production of the
    first element from the generator. Mutating global variables or external
    state can cause undefined behavior, and we recommend that you explicitly
    cache any external state in `generator` before calling
    `Dataset.from_generator()`.

    Args:
      generator: A callable object that returns an object that supports the
        `iter()` protocol. If `args` is not specified, `generator` must take no
        arguments; otherwise it must take as many arguments as there are values
        in `args`.
      output_types: (Optional.) A nested structure of `tf.DType` objects
        corresponding to each component of an element yielded by `generator`.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element yielded by `generator`.
      args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
        and passed to `generator` as NumPy-array arguments.
      output_signature: (Optional.) A nested structure of `tf.TypeSpec` objects
        corresponding to each component of an element yielded by `generator`.

    Returns:
      Dataset: A `Dataset`.
    """
    if not callable(generator):
      raise TypeError("`generator` must be callable.")

    if output_signature is not None:
      if output_types is not None:
        raise TypeError("`output_types` can not be used together with "
                        "`output_signature`")
      if output_shapes is not None:
        raise TypeError("`output_shapes` can not be used together with "
                        "`output_signature`")
      if not all(
          isinstance(_, type_spec.TypeSpec)
          for _ in nest.flatten(output_signature)):
        raise TypeError("All the elements of `output_signature` must be "
                        "`tf.TypeSpec` objects.")
    else:
      if output_types is None:
        raise TypeError("Either `output_signature` or `output_types` must "
                        "be specified")

    if output_signature is None:
      if output_shapes is None:
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_types)
      else:
        output_shapes = nest.map_structure_up_to(output_types,
                                                 tensor_shape.as_shape,
                                                 output_shapes)
      output_signature = nest.map_structure_up_to(output_types,
                                                  tensor_spec.TensorSpec,
                                                  output_shapes, output_types)
    if all([
        isinstance(x, tensor_spec.TensorSpec)
        for x in nest.flatten(output_signature)
    ]):
      output_types = nest.pack_sequence_as(
          output_signature, [x.dtype for x in nest.flatten(output_signature)])
      output_shapes = nest.pack_sequence_as(
          output_signature, [x.shape for x in nest.flatten(output_signature)])

    if args is None:
      args = ()
    else:
      args = tuple(ops.convert_n_to_tensor(args, name="args"))

    generator_state = DatasetV2._GeneratorState(generator)

    def get_iterator_id_fn(unused_dummy):
      """Creates a unique `iterator_id` for each pass over the dataset.

      The returned `iterator_id` disambiguates between multiple concurrently
      existing iterators.

      Args:
        unused_dummy: Ignored value.

      Returns:
        A `tf.int64` tensor whose value uniquely identifies an iterator in
        `generator_state`.
865 """ 866 return script_ops.numpy_function(generator_state.get_next_id, args, 867 dtypes.int64) 868 869 def generator_next_fn(iterator_id_t): 870 """Generates the next element from iterator with ID `iterator_id_t`. 871 872 We map this function across an infinite repetition of the 873 `iterator_id_t`, and raise `StopIteration` to terminate the iteration. 874 875 Args: 876 iterator_id_t: A `tf.int64` tensor whose value uniquely identifies the 877 iterator in `generator_state` from which to generate an element. 878 879 Returns: 880 The next element to generate from the iterator. 881 """ 882 if output_types and output_shapes: 883 flattened_types = [ 884 dtypes.as_dtype(dt) for dt in nest.flatten(output_types) 885 ] 886 flattened_shapes = nest.flatten(output_shapes) 887 888 def generator_py_func(iterator_id): 889 """A `py_func` that will be called to invoke the iterator.""" 890 # `next()` raises `StopIteration` when there are no more 891 # elements remaining to be generated. 892 values = next(generator_state.get_iterator(iterator_id)) 893 894 # Use the same _convert function from the py_func() implementation to 895 # convert the returned values to arrays early, so that we can inspect 896 # their values. 897 try: 898 flattened_values = nest.flatten_up_to(output_types, values) 899 except (TypeError, ValueError): 900 six.reraise( 901 TypeError, 902 TypeError( 903 "`generator` yielded an element that did not match the " 904 "expected structure. The expected structure was %s, but " 905 "the yielded element was %s." % (output_types, values)), 906 sys.exc_info()[2]) 907 ret_arrays = [] 908 for ret, dtype in zip(flattened_values, flattened_types): 909 try: 910 ret_arrays.append( 911 script_ops.FuncRegistry._convert( # pylint: disable=protected-access 912 ret, 913 dtype=dtype.as_numpy_dtype)) 914 except (TypeError, ValueError): 915 six.reraise( 916 TypeError, 917 TypeError( 918 "`generator` yielded an element that could not be " 919 "converted to the expected type. The expected type was " 920 "%s, but the yielded element was %s." % 921 (dtype.name, ret)), 922 sys.exc_info()[2]) 923 924 # Additional type and shape checking to ensure that the components of 925 # the generated element match the `output_types` and `output_shapes` 926 # arguments. 927 for (ret_array, expected_dtype, 928 expected_shape) in zip(ret_arrays, flattened_types, 929 flattened_shapes): 930 if ret_array.dtype != expected_dtype.as_numpy_dtype: 931 raise TypeError( 932 "`generator` yielded an element of type %s where an element " 933 "of type %s was expected." % 934 (ret_array.dtype, expected_dtype.as_numpy_dtype)) 935 if not expected_shape.is_compatible_with(ret_array.shape): 936 raise ValueError( 937 "`generator` yielded an element of shape %s where an element " 938 "of shape %s was expected." % 939 (ret_array.shape, expected_shape)) 940 941 return ret_arrays 942 943 flat_values = script_ops.numpy_function(generator_py_func, 944 [iterator_id_t], 945 flattened_types) 946 947 # The `py_func()` op drops the inferred shapes, so we add them back in 948 # here. 949 if output_shapes is not None: 950 for ret_t, shape in zip(flat_values, flattened_shapes): 951 ret_t.set_shape(shape) 952 953 return nest.pack_sequence_as(output_types, flat_values) 954 else: 955 flat_output_types = structure.get_flat_tensor_types(output_signature) 956 957 def generator_py_func(iterator_id): 958 """A `py_func` that will be called to invoke the iterator.""" 959 # `next()` raises `StopIteration` when there are no more 960 # elements remaining to be generated. 
          values = next(generator_state.get_iterator(iterator_id.numpy()))

          try:
            values = structure.normalize_element(values, output_signature)
          except (TypeError, ValueError):
            six.reraise(
                TypeError,
                TypeError(
                    "`generator` yielded an element that did not match the "
                    "expected structure. The expected structure was %s, but "
                    "the yielded element was %s." % (output_signature, values)),
                sys.exc_info()[2])

          values_spec = structure.type_spec_from_value(values)

          if not structure.are_compatible(values_spec, output_signature):
            raise TypeError(
                "`generator` yielded an element of %s where an element "
                "of %s was expected." % (values_spec, output_signature))

          return structure.to_tensor_list(output_signature, values)

        return script_ops._eager_py_func(  # pylint: disable=protected-access
            generator_py_func,
            inp=[iterator_id_t],
            Tout=flat_output_types,
            use_tape_cache=False)

    def finalize_fn(iterator_id_t):
      """Releases host-side state for the iterator with ID `iterator_id_t`."""

      def finalize_py_func(iterator_id):
        generator_state.iterator_completed(iterator_id)
        # We return a dummy value so that the `finalize_fn` has a valid
        # signature.
        # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
        # casting in `py_func()` will create an array of `np.int32` on Windows,
        # leading to a runtime error.
        return np.array(0, dtype=np.int64)

      return script_ops.numpy_function(finalize_py_func, [iterator_id_t],
                                       dtypes.int64)

    # This function associates each traversal of `generator` with a unique
    # iterator ID.
    def flat_map_fn(dummy_arg):
      # The `get_iterator_id_fn` gets a unique ID for the current instance of
      # the generator.
      # The `generator_next_fn` gets the next element from the iterator with
      # the given ID, and raises StopIteration when that iterator contains no
      # more elements.
      return _GeneratorDataset(dummy_arg, get_iterator_id_fn,
                               generator_next_fn, finalize_fn,
                               output_signature)

    # A single-element dataset that, each time it is evaluated, contains a
    # freshly-generated and unique (for the returned dataset) int64
    # ID that will be used to identify the appropriate Python state, which
    # is encapsulated in `generator_state`, and captured in
    # `get_iterator_id_map_fn`.
    dummy = 0
    id_dataset = Dataset.from_tensors(dummy)

    # A dataset that contains all of the elements generated by a
    # single iterator created from `generator`, identified by the
    # iterator ID contained in `id_dataset`. Lifting the iteration
    # into a flat_map here enables multiple repetitions and/or nested
    # versions of the returned dataset to be created, because it forces
    # the generation of a new ID for each version.
    return id_dataset.flat_map(flat_map_fn)
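
  # Illustrative sketch (commented out, not executed): passing `args` to
  # `from_generator` so each value is forwarded to the generator as a
  # NumPy-array argument. The generator name and argument below are assumptions
  # made up for the example.
  #
  #   def counting_gen(stop):
  #     for i in range(stop):
  #       yield i
  #
  #   ds = tf.data.Dataset.from_generator(
  #       counting_gen,
  #       args=(4,),
  #       output_signature=tf.TensorSpec(shape=(), dtype=tf.int64))
  #   # list(ds.as_numpy_iterator()) == [0, 1, 2, 3]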

  @staticmethod
  def range(*args, **kwargs):
    """Creates a `Dataset` of a step-separated range of values.

    >>> list(Dataset.range(5).as_numpy_iterator())
    [0, 1, 2, 3, 4]
    >>> list(Dataset.range(2, 5).as_numpy_iterator())
    [2, 3, 4]
    >>> list(Dataset.range(1, 5, 2).as_numpy_iterator())
    [1, 3]
    >>> list(Dataset.range(1, 5, -2).as_numpy_iterator())
    []
    >>> list(Dataset.range(5, 1).as_numpy_iterator())
    []
    >>> list(Dataset.range(5, 1, -2).as_numpy_iterator())
    [5, 3]
    >>> list(Dataset.range(2, 5, output_type=tf.int32).as_numpy_iterator())
    [2, 3, 4]
    >>> list(Dataset.range(1, 5, 2, output_type=tf.float32).as_numpy_iterator())
    [1.0, 3.0]

    Args:
      *args: follows the same semantics as python's xrange.
        len(args) == 1 -> start = 0, stop = args[0], step = 1.
        len(args) == 2 -> start = args[0], stop = args[1], step = 1.
        len(args) == 3 -> start = args[0], stop = args[1], step = args[2].
      **kwargs:
        - output_type: Its expected dtype. (Optional, default: `tf.int64`).

    Returns:
      Dataset: A `RangeDataset`.

    Raises:
      ValueError: if len(args) == 0.
    """
    return RangeDataset(*args, **kwargs)

  @staticmethod
  def zip(datasets):
    """Creates a `Dataset` by zipping together the given datasets.

    This method has similar semantics to the built-in `zip()` function
    in Python, with the main difference being that the `datasets`
    argument can be an arbitrary nested structure of `Dataset` objects.

    >>> # The nested structure of the `datasets` argument determines the
    >>> # structure of elements in the resulting dataset.
    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
    >>> b = tf.data.Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
    >>> ds = tf.data.Dataset.zip((a, b))
    >>> list(ds.as_numpy_iterator())
    [(1, 4), (2, 5), (3, 6)]
    >>> ds = tf.data.Dataset.zip((b, a))
    >>> list(ds.as_numpy_iterator())
    [(4, 1), (5, 2), (6, 3)]
    >>>
    >>> # The `datasets` argument may contain an arbitrary number of datasets.
    >>> c = tf.data.Dataset.range(7, 13).batch(2)  # ==> [ [7, 8],
    ...                                            #       [9, 10],
    ...                                            #       [11, 12] ]
    >>> ds = tf.data.Dataset.zip((a, b, c))
    >>> for element in ds.as_numpy_iterator():
    ...   print(element)
    (1, 4, array([7, 8]))
    (2, 5, array([ 9, 10]))
    (3, 6, array([11, 12]))
    >>>
    >>> # The number of elements in the resulting dataset is the same as
    >>> # the size of the smallest dataset in `datasets`.
    >>> d = tf.data.Dataset.range(13, 15)  # ==> [ 13, 14 ]
    >>> ds = tf.data.Dataset.zip((a, d))
    >>> list(ds.as_numpy_iterator())
    [(1, 13), (2, 14)]

    Args:
      datasets: A nested structure of datasets.

    Returns:
      Dataset: A `Dataset`.
    """
    return ZipDataset(datasets)

  def concatenate(self, dataset):
    """Creates a `Dataset` by concatenating the given dataset with this dataset.

    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
    >>> b = tf.data.Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
    >>> ds = a.concatenate(b)
    >>> list(ds.as_numpy_iterator())
    [1, 2, 3, 4, 5, 6, 7]
    >>> # The input dataset and dataset to be concatenated should have the same
    >>> # nested structures and output types.
    >>> c = tf.data.Dataset.zip((a, b))
    >>> a.concatenate(c)
    Traceback (most recent call last):
    TypeError: Two datasets to concatenate have different types
    <dtype: 'int64'> and (tf.int64, tf.int64)
    >>> d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
    >>> a.concatenate(d)
    Traceback (most recent call last):
    TypeError: Two datasets to concatenate have different types
    <dtype: 'int64'> and <dtype: 'string'>

    Args:
      dataset: `Dataset` to be concatenated.

    Returns:
      Dataset: A `Dataset`.
    """
    return ConcatenateDataset(self, dataset)

  def prefetch(self, buffer_size):
    """Creates a `Dataset` that prefetches elements from this dataset.

    Most dataset input pipelines should end with a call to `prefetch`. This
    allows later elements to be prepared while the current element is being
    processed. This often improves latency and throughput, at the cost of
    using additional memory to store prefetched elements.

    Note: Like other `Dataset` methods, prefetch operates on the
    elements of the input dataset. It has no concept of examples vs. batches.
    `examples.prefetch(2)` will prefetch two elements (2 examples),
    while `examples.batch(20).prefetch(2)` will prefetch 2 elements
    (2 batches, of 20 examples each).

    >>> dataset = tf.data.Dataset.range(3)
    >>> dataset = dataset.prefetch(2)
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 2]

    Args:
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the maximum
        number of elements that will be buffered when prefetching. If the value
        `tf.data.AUTOTUNE` is used, then the buffer size is dynamically tuned.

    Returns:
      Dataset: A `Dataset`.
    """
    return PrefetchDataset(self, buffer_size)

  @staticmethod
  def list_files(file_pattern, shuffle=None, seed=None):
    """A dataset of all files matching one or more glob patterns.

    The `file_pattern` argument should be a small number of glob patterns.
    If your filenames have already been globbed, use
    `Dataset.from_tensor_slices(filenames)` instead, as re-globbing every
    filename with `list_files` may result in poor performance with remote
    storage systems.

    Note: The default behavior of this method is to return filenames in
    a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False`
    to get results in a deterministic order.

    Example:
      If we had the following files on our filesystem:

        - /path/to/dir/a.txt
        - /path/to/dir/b.py
        - /path/to/dir/c.py

      If we pass "/path/to/dir/*.py" as the pattern, the dataset
      would produce:

        - /path/to/dir/b.py
        - /path/to/dir/c.py

    Args:
      file_pattern: A string, a list of strings, or a `tf.Tensor` of string type
        (scalar or vector), representing the filename glob (i.e. shell wildcard)
        pattern(s) that will be matched.
      shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
        Defaults to `True`.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
        seed that will be used to create the distribution. See
        `tf.random.set_seed` for behavior.

    Returns:
      Dataset: A `Dataset` of strings corresponding to file names.
1209 """ 1210 with ops.name_scope("list_files"): 1211 if shuffle is None: 1212 shuffle = True 1213 file_pattern = ops.convert_to_tensor( 1214 file_pattern, dtype=dtypes.string, name="file_pattern") 1215 matching_files = gen_io_ops.matching_files(file_pattern) 1216 1217 # Raise an exception if `file_pattern` does not match any files. 1218 condition = math_ops.greater(array_ops.shape(matching_files)[0], 0, 1219 name="match_not_empty") 1220 1221 message = math_ops.add( 1222 "No files matched pattern: ", 1223 string_ops.reduce_join(file_pattern, separator=", "), name="message") 1224 1225 assert_not_empty = control_flow_ops.Assert( 1226 condition, [message], summarize=1, name="assert_not_empty") 1227 with ops.control_dependencies([assert_not_empty]): 1228 matching_files = array_ops.identity(matching_files) 1229 1230 dataset = Dataset.from_tensor_slices(matching_files) 1231 if shuffle: 1232 # NOTE(mrry): The shuffle buffer size must be greater than zero, but the 1233 # list of files might be empty. 1234 buffer_size = math_ops.maximum( 1235 array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1) 1236 dataset = dataset.shuffle(buffer_size, seed=seed) 1237 return dataset 1238 1239 def repeat(self, count=None): 1240 """Repeats this dataset so each original value is seen `count` times. 1241 1242 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 1243 >>> dataset = dataset.repeat(3) 1244 >>> list(dataset.as_numpy_iterator()) 1245 [1, 2, 3, 1, 2, 3, 1, 2, 3] 1246 1247 Note: If this dataset is a function of global state (e.g. a random number 1248 generator), then different repetitions may produce different elements. 1249 1250 Args: 1251 count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 1252 number of times the dataset should be repeated. The default behavior (if 1253 `count` is `None` or `-1`) is for the dataset be repeated indefinitely. 1254 1255 Returns: 1256 Dataset: A `Dataset`. 1257 """ 1258 return RepeatDataset(self, count) 1259 1260 def enumerate(self, start=0): 1261 """Enumerates the elements of this dataset. 1262 1263 It is similar to python's `enumerate`. 1264 1265 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 1266 >>> dataset = dataset.enumerate(start=5) 1267 >>> for element in dataset.as_numpy_iterator(): 1268 ... print(element) 1269 (5, 1) 1270 (6, 2) 1271 (7, 3) 1272 1273 >>> # The nested structure of the input dataset determines the structure of 1274 >>> # elements in the resulting dataset. 1275 >>> dataset = tf.data.Dataset.from_tensor_slices([(7, 8), (9, 10)]) 1276 >>> dataset = dataset.enumerate() 1277 >>> for element in dataset.as_numpy_iterator(): 1278 ... print(element) 1279 (0, array([7, 8], dtype=int32)) 1280 (1, array([ 9, 10], dtype=int32)) 1281 1282 Args: 1283 start: A `tf.int64` scalar `tf.Tensor`, representing the start value for 1284 enumeration. 1285 1286 Returns: 1287 Dataset: A `Dataset`. 1288 """ 1289 1290 max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max 1291 return Dataset.zip((Dataset.range(start, max_value), self)) 1292 1293 def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None): 1294 """Randomly shuffles the elements of this dataset. 1295 1296 This dataset fills a buffer with `buffer_size` elements, then randomly 1297 samples elements from this buffer, replacing the selected elements with new 1298 elements. For perfect shuffling, a buffer size greater than or equal to the 1299 full size of the dataset is required. 

    For instance, if your dataset contains 10,000 elements but `buffer_size` is
    set to 1,000, then `shuffle` will initially select a random element from
    only the first 1,000 elements in the buffer. Once an element is selected,
    its space in the buffer is replaced by the next (i.e. 1,001-st) element,
    maintaining the 1,000 element buffer.

    `reshuffle_each_iteration` controls whether the shuffle order should be
    different for each epoch. In TF 1.X, the idiomatic way to create epochs
    was through the `repeat` transformation:

    >>> dataset = tf.data.Dataset.range(3)
    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
    >>> dataset = dataset.repeat(2)  # doctest: +SKIP
    [1, 0, 2, 1, 2, 0]

    >>> dataset = tf.data.Dataset.range(3)
    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
    >>> dataset = dataset.repeat(2)  # doctest: +SKIP
    [1, 0, 2, 1, 0, 2]

    In TF 2.0, `tf.data.Dataset` objects are Python iterables which makes it
    possible to also create epochs through Python iteration:

    >>> dataset = tf.data.Dataset.range(3)
    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
    [1, 0, 2]
    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
    [1, 2, 0]

    >>> dataset = tf.data.Dataset.range(3)
    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
    [1, 0, 2]
    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
    [1, 0, 2]

    Args:
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        elements from this dataset from which the new dataset will sample.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
        seed that will be used to create the distribution. See
        `tf.random.set_seed` for behavior.
      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
        that the dataset should be pseudorandomly reshuffled each time it is
        iterated over. (Defaults to `True`.)

    Returns:
      Dataset: A `Dataset`.
    """
    return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)
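
  # Illustrative sketch (commented out, not executed): the order of `shuffle`
  # and `repeat` changes what gets mixed. Shuffling before repeating reshuffles
  # within each epoch, while repeating before shuffling can mix elements across
  # epoch boundaries. This is a hedged usage note, not part of the API.
  #
  #   ds = tf.data.Dataset.range(3)
  #   per_epoch = ds.shuffle(3).repeat(2)  # each epoch is a permutation of 0..2
  #   across = ds.repeat(2).shuffle(6)     # elements may repeat before others appear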

  def cache(self, filename=""):
    """Caches the elements in this dataset.

    The first time the dataset is iterated over, its elements will be cached
    either in the specified file or in memory. Subsequent iterations will
    use the cached data.

    Note: For the cache to be finalized, the input dataset must be iterated
    through in its entirety. Otherwise, subsequent iterations will not use
    cached data.

    >>> dataset = tf.data.Dataset.range(5)
    >>> dataset = dataset.map(lambda x: x**2)
    >>> dataset = dataset.cache()
    >>> # The first time reading through the data will generate the data using
    >>> # `range` and `map`.
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 4, 9, 16]
    >>> # Subsequent iterations read from the cache.
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 4, 9, 16]

    When caching to a file, the cached data will persist across runs. Even the
    first iteration through the data will read from the cache file. Changing
    the input pipeline before the call to `.cache()` will have no effect until
    the cache file is removed or the filename is changed.

    >>> dataset = tf.data.Dataset.range(5)
    >>> dataset = dataset.cache("/path/to/file")  # doctest: +SKIP
    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
    [0, 1, 2, 3, 4]
    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.cache("/path/to/file")  # Same file!  # doctest: +SKIP
    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
    [0, 1, 2, 3, 4]

    Note: `cache` will produce exactly the same elements during each iteration
    through the dataset. If you wish to randomize the iteration order, make
    sure to call `shuffle` *after* calling `cache`.

    Args:
      filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
        directory on the filesystem to use for caching elements in this
        Dataset. If a filename is not provided, the dataset will be cached in
        memory.

    Returns:
      Dataset: A `Dataset`.
    """
    return CacheDataset(self, filename)

  def take(self, count):
    """Creates a `Dataset` with at most `count` elements from this dataset.

    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.take(3)
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 2]

    Args:
      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
        elements of this dataset that should be taken to form the new dataset.
        If `count` is -1, or if `count` is greater than the size of this
        dataset, the new dataset will contain all elements of this dataset.

    Returns:
      Dataset: A `Dataset`.
    """
    return TakeDataset(self, count)

  def skip(self, count):
    """Creates a `Dataset` that skips `count` elements from this dataset.

    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.skip(7)
    >>> list(dataset.as_numpy_iterator())
    [7, 8, 9]

    Args:
      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
        elements of this dataset that should be skipped to form the new
        dataset. If `count` is greater than the size of this dataset, the new
        dataset will contain no elements. If `count` is -1, skips the entire
        dataset.

    Returns:
      Dataset: A `Dataset`.
    """
    return SkipDataset(self, count)
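
  # Illustrative sketch (commented out, not executed): `skip` and `take`
  # compose to select a contiguous window of elements; a hedged usage note
  # only.
  #
  #   ds = tf.data.Dataset.range(10)
  #   window = ds.skip(2).take(3)
  #   # list(window.as_numpy_iterator()) == [2, 3, 4]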
For example, when reading from a set of TFRecord files, shard 1477 before converting the dataset to input samples. This avoids reading every 1478 file on every worker. The following is an example of an efficient 1479 sharding strategy within a complete pipeline: 1480 1481 ```python 1482 d = Dataset.list_files(pattern) 1483 d = d.shard(num_workers, worker_index) 1484 d = d.repeat(num_epochs) 1485 d = d.shuffle(shuffle_buffer_size) 1486 d = d.interleave(tf.data.TFRecordDataset, 1487 cycle_length=num_readers, block_length=1) 1488 d = d.map(parser_fn, num_parallel_calls=num_map_threads) 1489 ``` 1490 1491 Args: 1492 num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of 1493 shards operating in parallel. 1494 index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. 1495 1496 Returns: 1497 Dataset: A `Dataset`. 1498 1499 Raises: 1500 InvalidArgumentError: if `num_shards` or `index` are illegal values. 1501 1502 Note: error checking is done on a best-effort basis, and errors aren't 1503 guaranteed to be caught upon dataset creation. (e.g. providing in a 1504 placeholder tensor bypasses the early checking, and will instead result 1505 in an error during a session.run call.) 1506 """ 1507 return ShardDataset(self, num_shards, index) 1508 1509 def batch(self, batch_size, drop_remainder=False, num_parallel_calls=None): 1510 """Combines consecutive elements of this dataset into batches. 1511 1512 >>> dataset = tf.data.Dataset.range(8) 1513 >>> dataset = dataset.batch(3) 1514 >>> list(dataset.as_numpy_iterator()) 1515 [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])] 1516 1517 >>> dataset = tf.data.Dataset.range(8) 1518 >>> dataset = dataset.batch(3, drop_remainder=True) 1519 >>> list(dataset.as_numpy_iterator()) 1520 [array([0, 1, 2]), array([3, 4, 5])] 1521 1522 The components of the resulting element will have an additional outer 1523 dimension, which will be `batch_size` (or `N % batch_size` for the last 1524 element if `batch_size` does not divide the number of input elements `N` 1525 evenly and `drop_remainder` is `False`). If your program depends on the 1526 batches having the same outer dimension, you should set the `drop_remainder` 1527 argument to `True` to prevent the smaller batch from being produced. 1528 1529 Args: 1530 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 1531 consecutive elements of this dataset to combine in a single batch. 1532 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 1533 whether the last batch should be dropped in the case it has fewer than 1534 `batch_size` elements; the default behavior is not to drop the smaller 1535 batch. 1536 num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`, 1537 representing the number of batches to compute asynchronously in 1538 parallel. 1539 If not specified, batches will be computed sequentially. If the value 1540 `tf.data.AUTOTUNE` is used, then the number of parallel 1541 calls is set dynamically based on available resources. 1542 1543 Returns: 1544 Dataset: A `Dataset`. 1545 """ 1546 if num_parallel_calls is None: 1547 return BatchDataset(self, batch_size, drop_remainder) 1548 else: 1549 return ParallelBatchDataset(self, batch_size, drop_remainder, 1550 num_parallel_calls) 1551 1552 def padded_batch(self, 1553 batch_size, 1554 padded_shapes=None, 1555 padding_values=None, 1556 drop_remainder=False): 1557 """Combines consecutive elements of this dataset into padded batches. 
1558 1559 This transformation combines multiple consecutive elements of the input 1560 dataset into a single element. 1561 1562 Like `tf.data.Dataset.batch`, the components of the resulting element will 1563 have an additional outer dimension, which will be `batch_size` (or 1564 `N % batch_size` for the last element if `batch_size` does not divide the 1565 number of input elements `N` evenly and `drop_remainder` is `False`). If 1566 your program depends on the batches having the same outer dimension, you 1567 should set the `drop_remainder` argument to `True` to prevent the smaller 1568 batch from being produced. 1569 1570 Unlike `tf.data.Dataset.batch`, the input elements to be batched may have 1571 different shapes, and this transformation will pad each component to the 1572 respective shape in `padded_shapes`. The `padded_shapes` argument 1573 determines the resulting shape for each dimension of each component in an 1574 output element: 1575 1576 * If the dimension is a constant, the component will be padded out to that 1577 length in that dimension. 1578 * If the dimension is unknown, the component will be padded out to the 1579 maximum length of all elements in that dimension. 1580 1581 >>> A = (tf.data.Dataset 1582 ... .range(1, 5, output_type=tf.int32) 1583 ... .map(lambda x: tf.fill([x], x))) 1584 >>> # Pad to the smallest per-batch size that fits all elements. 1585 >>> B = A.padded_batch(2) 1586 >>> for element in B.as_numpy_iterator(): 1587 ... print(element) 1588 [[1 0] 1589 [2 2]] 1590 [[3 3 3 0] 1591 [4 4 4 4]] 1592 >>> # Pad to a fixed size. 1593 >>> C = A.padded_batch(2, padded_shapes=5) 1594 >>> for element in C.as_numpy_iterator(): 1595 ... print(element) 1596 [[1 0 0 0 0] 1597 [2 2 0 0 0]] 1598 [[3 3 3 0 0] 1599 [4 4 4 4 0]] 1600 >>> # Pad with a custom value. 1601 >>> D = A.padded_batch(2, padded_shapes=5, padding_values=-1) 1602 >>> for element in D.as_numpy_iterator(): 1603 ... print(element) 1604 [[ 1 -1 -1 -1 -1] 1605 [ 2 2 -1 -1 -1]] 1606 [[ 3 3 3 -1 -1] 1607 [ 4 4 4 4 -1]] 1608 >>> # Components of nested elements can be padded independently. 1609 >>> elements = [([1, 2, 3], [10]), 1610 ... ([4, 5], [11, 12])] 1611 >>> dataset = tf.data.Dataset.from_generator( 1612 ... lambda: iter(elements), (tf.int32, tf.int32)) 1613 >>> # Pad the first component of the tuple to length 4, and the second 1614 >>> # component to the smallest size that fits. 1615 >>> dataset = dataset.padded_batch(2, 1616 ... padded_shapes=([4], [None]), 1617 ... padding_values=(-1, 100)) 1618 >>> list(dataset.as_numpy_iterator()) 1619 [(array([[ 1, 2, 3, -1], [ 4, 5, -1, -1]], dtype=int32), 1620 array([[ 10, 100], [ 11, 12]], dtype=int32))] 1621 >>> # Pad with a single value and multiple components. 1622 >>> E = tf.data.Dataset.zip((A, A)).padded_batch(2, padding_values=-1) 1623 >>> for element in E.as_numpy_iterator(): 1624 ... print(element) 1625 (array([[ 1, -1], 1626 [ 2, 2]], dtype=int32), array([[ 1, -1], 1627 [ 2, 2]], dtype=int32)) 1628 (array([[ 3, 3, 3, -1], 1629 [ 4, 4, 4, 4]], dtype=int32), array([[ 3, 3, 3, -1], 1630 [ 4, 4, 4, 4]], dtype=int32)) 1631 1632 See also `tf.data.experimental.dense_to_sparse_batch`, which combines 1633 elements that may have different shapes into a `tf.sparse.SparseTensor`. 1634 1635 Args: 1636 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 1637 consecutive elements of this dataset to combine in a single batch. 1638 padded_shapes: (Optional.) 
A nested structure of `tf.TensorShape` or 1639 `tf.int64` vector tensor-like objects representing the shape to which 1640 the respective component of each input element should be padded prior 1641 to batching. Any unknown dimensions will be padded to the maximum size 1642 of that dimension in each batch. If unset, all dimensions of all 1643 components are padded to the maximum size in the batch. `padded_shapes` 1644 must be set if any component has an unknown rank. 1645 padding_values: (Optional.) A nested structure of scalar-shaped 1646 `tf.Tensor`, representing the padding values to use for the respective 1647 components. None represents that the nested structure should be padded 1648 with default values. Defaults are `0` for numeric types and the empty 1649 string for string types. The `padding_values` should have the 1650 same structure as the input dataset. If `padding_values` is a single 1651 element and the input dataset has multiple components, then the same 1652 `padding_values` will be used to pad every component of the dataset. 1653 If `padding_values` is a scalar, then its value will be broadcasted 1654 to match the shape of each component. 1655 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 1656 whether the last batch should be dropped in the case it has fewer than 1657 `batch_size` elements; the default behavior is not to drop the smaller 1658 batch. 1659 1660 Returns: 1661 Dataset: A `Dataset`. 1662 1663 Raises: 1664 ValueError: If a component has an unknown rank, and the `padded_shapes` 1665 argument is not set. 1666 """ 1667 if padded_shapes is None: 1668 padded_shapes = get_legacy_output_shapes(self) 1669 # A `tf.TensorShape` is only false if its *rank* is unknown: 1670 # bool(tf.TensorShape(None)) is False 1671 if not all(nest.flatten(padded_shapes)): 1672 raise ValueError("You must set the `padded_shapes` argument to " 1673 "`Dataset.padded_batch` if any component of its " 1674 "input has an unknown rank") 1675 return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values, 1676 drop_remainder) 1677 1678 def map(self, map_func, num_parallel_calls=None, deterministic=None): 1679 """Maps `map_func` across the elements of this dataset. 1680 1681 This transformation applies `map_func` to each element of this dataset, and 1682 returns a new dataset containing the transformed elements, in the same 1683 order as they appeared in the input. `map_func` can be used to change both 1684 the values and the structure of a dataset's elements. For example, adding 1 1685 to each element, or projecting a subset of element components. 1686 1687 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] 1688 >>> dataset = dataset.map(lambda x: x + 1) 1689 >>> list(dataset.as_numpy_iterator()) 1690 [2, 3, 4, 5, 6] 1691 1692 The input signature of `map_func` is determined by the structure of each 1693 element in this dataset. 1694 1695 >>> dataset = Dataset.range(5) 1696 >>> # `map_func` takes a single argument of type `tf.Tensor` with the same 1697 >>> # shape and dtype. 1698 >>> result = dataset.map(lambda x: x + 1) 1699 1700 >>> # Each element is a tuple containing two `tf.Tensor` objects. 1701 >>> elements = [(1, "foo"), (2, "bar"), (3, "baz")] 1702 >>> dataset = tf.data.Dataset.from_generator( 1703 ... lambda: elements, (tf.int32, tf.string)) 1704 >>> # `map_func` takes two arguments of type `tf.Tensor`. This function 1705 >>> # projects out just the first component. 
1706 >>> result = dataset.map(lambda x_int, y_str: x_int) 1707 >>> list(result.as_numpy_iterator()) 1708 [1, 2, 3] 1709 1710 >>> # Each element is a dictionary mapping strings to `tf.Tensor` objects. 1711 >>> elements = ([{"a": 1, "b": "foo"}, 1712 ... {"a": 2, "b": "bar"}, 1713 ... {"a": 3, "b": "baz"}]) 1714 >>> dataset = tf.data.Dataset.from_generator( 1715 ... lambda: elements, {"a": tf.int32, "b": tf.string}) 1716 >>> # `map_func` takes a single argument of type `dict` with the same keys 1717 >>> # as the elements. 1718 >>> result = dataset.map(lambda d: str(d["a"]) + d["b"]) 1719 1720 The value or values returned by `map_func` determine the structure of each 1721 element in the returned dataset. 1722 1723 >>> dataset = tf.data.Dataset.range(3) 1724 >>> # `map_func` returns two `tf.Tensor` objects. 1725 >>> def g(x): 1726 ... return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"]) 1727 >>> result = dataset.map(g) 1728 >>> result.element_spec 1729 (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(3,), \ 1730dtype=tf.string, name=None)) 1731 >>> # Python primitives, lists, and NumPy arrays are implicitly converted to 1732 >>> # `tf.Tensor`. 1733 >>> def h(x): 1734 ... return 37.0, ["Foo", "Bar"], np.array([1.0, 2.0], dtype=np.float64) 1735 >>> result = dataset.map(h) 1736 >>> result.element_spec 1737 (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(2,), \ 1738dtype=tf.string, name=None), TensorSpec(shape=(2,), dtype=tf.float64, \ 1739name=None)) 1740 >>> # `map_func` can return nested structures. 1741 >>> def i(x): 1742 ... return (37.0, [42, 16]), "foo" 1743 >>> result = dataset.map(i) 1744 >>> result.element_spec 1745 ((TensorSpec(shape=(), dtype=tf.float32, name=None), 1746 TensorSpec(shape=(2,), dtype=tf.int32, name=None)), 1747 TensorSpec(shape=(), dtype=tf.string, name=None)) 1748 1749 `map_func` can accept as arguments and return any type of dataset element. 1750 1751 Note that irrespective of the context in which `map_func` is defined (eager 1752 vs. graph), tf.data traces the function and executes it as a graph. To use 1753 Python code inside of the function you have a few options: 1754 1755 1) Rely on AutoGraph to convert Python code into an equivalent graph 1756 computation. The downside of this approach is that AutoGraph can convert 1757 some but not all Python code. 1758 1759 2) Use `tf.py_function`, which allows you to write arbitrary Python code but 1760 will generally result in worse performance than 1). For example: 1761 1762 >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world']) 1763 >>> # transform a string tensor to upper case string using a Python function 1764 >>> def upper_case_fn(t: tf.Tensor): 1765 ... return t.numpy().decode('utf-8').upper() 1766 >>> d = d.map(lambda x: tf.py_function(func=upper_case_fn, 1767 ... inp=[x], Tout=tf.string)) 1768 >>> list(d.as_numpy_iterator()) 1769 [b'HELLO', b'WORLD'] 1770 1771 3) Use `tf.numpy_function`, which also allows you to write arbitrary 1772 Python code. Note that `tf.py_function` accepts `tf.Tensor` whereas 1773 `tf.numpy_function` accepts numpy arrays and returns only numpy arrays. 1774 For example: 1775 1776 >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world']) 1777 >>> def upper_case_fn(t: np.ndarray): 1778 ... return t.decode('utf-8').upper() 1779 >>> d = d.map(lambda x: tf.numpy_function(func=upper_case_fn, 1780 ... 
inp=[x], Tout=tf.string)) 1781 >>> list(d.as_numpy_iterator()) 1782 [b'HELLO', b'WORLD'] 1783 1784 Note that the use of `tf.numpy_function` and `tf.py_function` 1785 in general precludes the possibility of executing user-defined 1786 transformations in parallel (because of Python GIL). 1787 1788 Performance can often be improved by setting `num_parallel_calls` so that 1789 `map` will use multiple threads to process elements. If deterministic order 1790 isn't required, it can also improve performance to set 1791 `deterministic=False`. 1792 1793 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] 1794 >>> dataset = dataset.map(lambda x: x + 1, 1795 ... num_parallel_calls=tf.data.AUTOTUNE, 1796 ... deterministic=False) 1797 1798 The order of elements yielded by this transformation is deterministic if 1799 `deterministic=True`. If `map_func` contains stateful operations and 1800 `num_parallel_calls > 1`, the order in which that state is accessed is 1801 undefined, so the values of output elements may not be deterministic 1802 regardless of the `deterministic` flag value. 1803 1804 Args: 1805 map_func: A function mapping a dataset element to another dataset element. 1806 num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`, 1807 representing the number elements to process asynchronously in parallel. 1808 If not specified, elements will be processed sequentially. If the value 1809 `tf.data.AUTOTUNE` is used, then the number of parallel 1810 calls is set dynamically based on available CPU. 1811 deterministic: (Optional.) When `num_parallel_calls` is specified, this 1812 boolean controls the order in which the transformation produces 1813 elements. If set to `False`, the transformation is allowed to yield 1814 elements out of order to trade determinism for performance. If not 1815 specified, the `tf.data.Options.experimental_deterministic` option 1816 (`True` by default) controls the behavior. 1817 1818 Returns: 1819 Dataset: A `Dataset`. 1820 """ 1821 if num_parallel_calls is None: 1822 if deterministic is not None: 1823 warnings.warn("The `deterministic` argument has no effect unless the " 1824 "`num_parallel_calls` argument is specified.") 1825 return MapDataset(self, map_func, preserve_cardinality=True) 1826 else: 1827 return ParallelMapDataset( 1828 self, 1829 map_func, 1830 num_parallel_calls, 1831 deterministic, 1832 preserve_cardinality=True) 1833 1834 def flat_map(self, map_func): 1835 """Maps `map_func` across this dataset and flattens the result. 1836 1837 Use `flat_map` if you want to make sure that the order of your dataset 1838 stays the same. For example, to flatten a dataset of batches into a 1839 dataset of their elements: 1840 1841 >>> dataset = tf.data.Dataset.from_tensor_slices( 1842 ... [[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 1843 >>> dataset = dataset.flat_map(lambda x: Dataset.from_tensor_slices(x)) 1844 >>> list(dataset.as_numpy_iterator()) 1845 [1, 2, 3, 4, 5, 6, 7, 8, 9] 1846 1847 `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since 1848 `flat_map` produces the same output as 1849 `tf.data.Dataset.interleave(cycle_length=1)` 1850 1851 Args: 1852 map_func: A function mapping a dataset element to a dataset. 1853 1854 Returns: 1855 Dataset: A `Dataset`. 1856 """ 1857 return FlatMapDataset(self, map_func) 1858 1859 def interleave(self, 1860 map_func, 1861 cycle_length=None, 1862 block_length=None, 1863 num_parallel_calls=None, 1864 deterministic=None): 1865 """Maps `map_func` across this dataset, and interleaves the results. 
1866 1867 For example, you can use `Dataset.interleave()` to process many input files 1868 concurrently: 1869 1870 >>> # Preprocess 4 files concurrently, and interleave blocks of 16 records 1871 >>> # from each file. 1872 >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt", 1873 ... "/var/data/file3.txt", "/var/data/file4.txt"] 1874 >>> dataset = tf.data.Dataset.from_tensor_slices(filenames) 1875 >>> def parse_fn(filename): 1876 ... return tf.data.Dataset.range(10) 1877 >>> dataset = dataset.interleave(lambda x: 1878 ... tf.data.TextLineDataset(x).map(parse_fn, num_parallel_calls=1), 1879 ... cycle_length=4, block_length=16) 1880 1881 The `cycle_length` and `block_length` arguments control the order in which 1882 elements are produced. `cycle_length` controls the number of input elements 1883 that are processed concurrently. If you set `cycle_length` to 1, this 1884 transformation will handle one input element at a time, and will produce 1885 identical results to `tf.data.Dataset.flat_map`. In general, 1886 this transformation will apply `map_func` to `cycle_length` input elements, 1887 open iterators on the returned `Dataset` objects, and cycle through them 1888 producing `block_length` consecutive elements from each iterator, and 1889 consuming the next input element each time it reaches the end of an 1890 iterator. 1891 1892 For example: 1893 1894 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] 1895 >>> # NOTE: New lines indicate "block" boundaries. 1896 >>> dataset = dataset.interleave( 1897 ... lambda x: Dataset.from_tensors(x).repeat(6), 1898 ... cycle_length=2, block_length=4) 1899 >>> list(dataset.as_numpy_iterator()) 1900 [1, 1, 1, 1, 1901 2, 2, 2, 2, 1902 1, 1, 1903 2, 2, 1904 3, 3, 3, 3, 1905 4, 4, 4, 4, 1906 3, 3, 1907 4, 4, 1908 5, 5, 5, 5, 1909 5, 5] 1910 1911 Note: The order of elements yielded by this transformation is 1912 deterministic, as long as `map_func` is a pure function and 1913 `deterministic=True`. If `map_func` contains any stateful operations, the 1914 order in which that state is accessed is undefined. 1915 1916 Performance can often be improved by setting `num_parallel_calls` so that 1917 `interleave` will use multiple threads to fetch elements. If determinism 1918 isn't required, it can also improve performance to set 1919 `deterministic=False`. 1920 1921 >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt", 1922 ... "/var/data/file3.txt", "/var/data/file4.txt"] 1923 >>> dataset = tf.data.Dataset.from_tensor_slices(filenames) 1924 >>> dataset = dataset.interleave(lambda x: tf.data.TFRecordDataset(x), 1925 ... cycle_length=4, num_parallel_calls=tf.data.AUTOTUNE, 1926 ... deterministic=False) 1927 1928 Args: 1929 map_func: A function mapping a dataset element to a dataset. 1930 cycle_length: (Optional.) The number of input elements that will be 1931 processed concurrently. If not set, the tf.data runtime decides what it 1932 should be based on available CPU. If `num_parallel_calls` is set to 1933 `tf.data.AUTOTUNE`, the `cycle_length` argument identifies 1934 the maximum degree of parallelism. 1935 block_length: (Optional.) The number of consecutive elements to produce 1936 from each input element before cycling to another input element. If not 1937 set, defaults to 1. 1938 num_parallel_calls: (Optional.) If specified, the implementation creates a 1939 threadpool, which is used to fetch inputs from cycle elements 1940 asynchronously and in parallel. 
The default behavior is to fetch inputs 1941 from cycle elements synchronously with no parallelism. If the value 1942 `tf.data.AUTOTUNE` is used, then the number of parallel 1943 calls is set dynamically based on available CPU. 1944 deterministic: (Optional.) When `num_parallel_calls` is specified, this 1945 boolean controls the order in which the transformation produces 1946 elements. If set to `False`, the transformation is allowed to yield 1947 elements out of order to trade determinism for performance. If not 1948 specified, the `tf.data.Options.experimental_deterministic` option 1949 (`True` by default) controls the behavior. 1950 1951 Returns: 1952 Dataset: A `Dataset`. 1953 """ 1954 if block_length is None: 1955 block_length = 1 1956 1957 if cycle_length is None: 1958 cycle_length = AUTOTUNE 1959 1960 if num_parallel_calls is None: 1961 if deterministic is not None: 1962 warnings.warn("The `deterministic` argument has no effect unless the " 1963 "`num_parallel_calls` argument is specified.") 1964 return InterleaveDataset(self, map_func, cycle_length, block_length) 1965 else: 1966 return ParallelInterleaveDataset( 1967 self, 1968 map_func, 1969 cycle_length, 1970 block_length, 1971 num_parallel_calls, 1972 deterministic=deterministic) 1973 1974 def filter(self, predicate): 1975 """Filters this dataset according to `predicate`. 1976 1977 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 1978 >>> dataset = dataset.filter(lambda x: x < 3) 1979 >>> list(dataset.as_numpy_iterator()) 1980 [1, 2] 1981 >>> # `tf.math.equal(x, y)` is required for equality comparison 1982 >>> def filter_fn(x): 1983 ... return tf.math.equal(x, 1) 1984 >>> dataset = dataset.filter(filter_fn) 1985 >>> list(dataset.as_numpy_iterator()) 1986 [1] 1987 1988 Args: 1989 predicate: A function mapping a dataset element to a boolean. 1990 1991 Returns: 1992 Dataset: The `Dataset` containing the elements of this dataset for which 1993 `predicate` is `True`. 1994 """ 1995 return FilterDataset(self, predicate) 1996 1997 def apply(self, transformation_func): 1998 """Applies a transformation function to this dataset. 1999 2000 `apply` enables chaining of custom `Dataset` transformations, which are 2001 represented as functions that take one `Dataset` argument and return a 2002 transformed `Dataset`. 2003 2004 >>> dataset = tf.data.Dataset.range(100) 2005 >>> def dataset_fn(ds): 2006 ... return ds.filter(lambda x: x < 5) 2007 >>> dataset = dataset.apply(dataset_fn) 2008 >>> list(dataset.as_numpy_iterator()) 2009 [0, 1, 2, 3, 4] 2010 2011 Args: 2012 transformation_func: A function that takes one `Dataset` argument and 2013 returns a `Dataset`. 2014 2015 Returns: 2016 Dataset: The `Dataset` returned by applying `transformation_func` to this 2017 dataset. 2018 """ 2019 dataset = transformation_func(self) 2020 if not isinstance(dataset, DatasetV2): 2021 raise TypeError( 2022 "`transformation_func` must return a Dataset. Got {}.".format( 2023 dataset)) 2024 dataset._input_datasets = [self] # pylint: disable=protected-access 2025 return dataset 2026 2027 def window(self, size, shift=None, stride=1, drop_remainder=False): 2028 """Combines (nests of) input elements into a dataset of (nests of) windows. 2029 2030 A "window" is a finite dataset of flat elements of size `size` (or possibly 2031 fewer if there are not enough input elements to fill the window and 2032 `drop_remainder` evaluates to `False`). 2033 2034 The `shift` argument determines the number of input elements by which the 2035 window moves on each iteration. 
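    If `shift` is not specified, it defaults to `size`, so consecutive windows
    do not overlap.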
If windows and elements are both numbered 2036 starting at 0, the first element in window `k` will be element `k * shift` 2037 of the input dataset. In particular, the first element of the first window 2038 will always be the first element of the input dataset. 2039 2040 The `stride` argument determines the stride of the input elements, and the 2041 `shift` argument determines the shift of the window. 2042 2043 For example: 2044 2045 >>> dataset = tf.data.Dataset.range(7).window(2) 2046 >>> for window in dataset: 2047 ... print(list(window.as_numpy_iterator())) 2048 [0, 1] 2049 [2, 3] 2050 [4, 5] 2051 [6] 2052 >>> dataset = tf.data.Dataset.range(7).window(3, 2, 1, True) 2053 >>> for window in dataset: 2054 ... print(list(window.as_numpy_iterator())) 2055 [0, 1, 2] 2056 [2, 3, 4] 2057 [4, 5, 6] 2058 >>> dataset = tf.data.Dataset.range(7).window(3, 1, 2, True) 2059 >>> for window in dataset: 2060 ... print(list(window.as_numpy_iterator())) 2061 [0, 2, 4] 2062 [1, 3, 5] 2063 [2, 4, 6] 2064 2065 Note that when the `window` transformation is applied to a dataset of 2066 nested elements, it produces a dataset of nested windows. 2067 2068 >>> nested = ([1, 2, 3, 4], [5, 6, 7, 8]) 2069 >>> dataset = tf.data.Dataset.from_tensor_slices(nested).window(2) 2070 >>> for window in dataset: 2071 ... def to_numpy(ds): 2072 ... return list(ds.as_numpy_iterator()) 2073 ... print(tuple(to_numpy(component) for component in window)) 2074 ([1, 2], [5, 6]) 2075 ([3, 4], [7, 8]) 2076 2077 >>> dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3, 4]}) 2078 >>> dataset = dataset.window(2) 2079 >>> for window in dataset: 2080 ... def to_numpy(ds): 2081 ... return list(ds.as_numpy_iterator()) 2082 ... print({'a': to_numpy(window['a'])}) 2083 {'a': [1, 2]} 2084 {'a': [3, 4]} 2085 2086 Args: 2087 size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements 2088 of the input dataset to combine into a window. Must be positive. 2089 shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 2090 number of input elements by which the window moves in each iteration. 2091 Defaults to `size`. Must be positive. 2092 stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 2093 stride of the input elements in the sliding window. Must be positive. 2094 The default value of 1 means "retain every input element". 2095 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 2096 whether the last windows should be dropped if their size is smaller than 2097 `size`. 2098 2099 Returns: 2100 Dataset: A `Dataset` of (nests of) windows -- a finite datasets of flat 2101 elements created from the (nests of) input elements. 2102 2103 """ 2104 if shift is None: 2105 shift = size 2106 return WindowDataset(self, size, shift, stride, drop_remainder) 2107 2108 def reduce(self, initial_state, reduce_func): 2109 """Reduces the input dataset to a single element. 2110 2111 The transformation calls `reduce_func` successively on every element of 2112 the input dataset until the dataset is exhausted, aggregating information in 2113 its internal state. The `initial_state` argument is used for the initial 2114 state and the final state is returned as the result. 2115 2116 >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1).numpy() 2117 5 2118 >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y).numpy() 2119 10 2120 2121 Args: 2122 initial_state: An element representing the initial state of the 2123 transformation. 
2124 reduce_func: A function that maps `(old_state, input_element)` to 2125 `new_state`. It must take two arguments and return a new element 2126 The structure of `new_state` must match the structure of 2127 `initial_state`. 2128 2129 Returns: 2130 A dataset element corresponding to the final state of the transformation. 2131 2132 """ 2133 2134 with ops.name_scope("initial_state"): 2135 initial_state = structure.normalize_element(initial_state) 2136 state_structure = structure.type_spec_from_value(initial_state) 2137 2138 # Iteratively rerun the reduce function until reaching a fixed point on 2139 # `state_structure`. 2140 need_to_rerun = True 2141 while need_to_rerun: 2142 2143 wrapped_func = StructuredFunctionWrapper( 2144 reduce_func, 2145 "reduce()", 2146 input_structure=(state_structure, self.element_spec), 2147 add_to_graph=False) 2148 2149 # Extract and validate class information from the returned values. 2150 output_classes = wrapped_func.output_classes 2151 state_classes = nest.map_structure( 2152 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 2153 state_structure) 2154 for new_state_class, state_class in zip( 2155 nest.flatten(output_classes), nest.flatten(state_classes)): 2156 if not issubclass(new_state_class, state_class): 2157 raise TypeError( 2158 "The element classes for the new state must match the initial " 2159 "state. Expected %s; got %s." % 2160 (state_classes, wrapped_func.output_classes)) 2161 2162 # Extract and validate type information from the returned values. 2163 output_types = wrapped_func.output_types 2164 state_types = nest.map_structure( 2165 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 2166 state_structure) 2167 for new_state_type, state_type in zip( 2168 nest.flatten(output_types), nest.flatten(state_types)): 2169 if new_state_type != state_type: 2170 raise TypeError( 2171 "The element types for the new state must match the initial " 2172 "state. Expected %s; got %s." % 2173 (state_types, wrapped_func.output_types)) 2174 2175 # Extract shape information from the returned values. 2176 output_shapes = wrapped_func.output_shapes 2177 state_shapes = nest.map_structure( 2178 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 2179 state_structure) 2180 flat_state_shapes = nest.flatten(state_shapes) 2181 flat_new_state_shapes = nest.flatten(output_shapes) 2182 weakened_state_shapes = [ 2183 original.most_specific_compatible_shape(new) 2184 for original, new in zip(flat_state_shapes, flat_new_state_shapes) 2185 ] 2186 2187 need_to_rerun = False 2188 for original_shape, weakened_shape in zip(flat_state_shapes, 2189 weakened_state_shapes): 2190 if original_shape.ndims is not None and ( 2191 weakened_shape.ndims is None or 2192 original_shape.as_list() != weakened_shape.as_list()): 2193 need_to_rerun = True 2194 break 2195 2196 if need_to_rerun: 2197 # TODO(b/110122868): Support a "most specific compatible structure" 2198 # method for combining structures, to avoid using legacy structures 2199 # here. 
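        # Rebuild `state_structure` from the weakened shapes so that the next
        # pass of the `while` loop retraces `reduce_func` against the relaxed
        # state structure.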
        state_structure = structure.convert_legacy_structure(
            state_types,
            nest.pack_sequence_as(state_shapes, weakened_state_shapes),
            state_classes)

    reduce_func = wrapped_func.function
    reduce_func.add_to_graph(ops.get_default_graph())

    dataset = self._apply_options()

    # pylint: disable=protected-access
    return structure.from_compatible_tensor_list(
        state_structure,
        gen_dataset_ops.reduce_dataset(
            dataset._variant_tensor,
            structure.to_tensor_list(state_structure, initial_state),
            reduce_func.captured_inputs,
            f=reduce_func,
            output_shapes=structure.get_flat_tensor_shapes(state_structure),
            output_types=structure.get_flat_tensor_types(state_structure)))

  def unbatch(self):
    """Splits elements of a dataset into multiple elements.

    For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
    where `B` may vary for each input element, then for each element in the
    dataset, the unbatched dataset will contain `B` consecutive elements
    of shape `[a0, a1, ...]`.

    >>> elements = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ]
    >>> dataset = tf.data.Dataset.from_generator(lambda: elements, tf.int64)
    >>> dataset = dataset.unbatch()
    >>> list(dataset.as_numpy_iterator())
    [1, 2, 3, 1, 2, 1, 2, 3, 4]

    Note: `unbatch` requires a data copy to slice up the batched tensor into
    smaller, unbatched tensors. When optimizing performance, try to avoid
    unnecessary usage of `unbatch`.

    Returns:
      A `Dataset`.
    """
    normalized_dataset = normalize_to_dense(self)
    return _UnbatchDataset(normalized_dataset)

  def with_options(self, options):
    """Returns a new `tf.data.Dataset` with the given options set.

    The options are "global" in the sense that they apply to the entire
    dataset. If options are set multiple times, they are merged as long as
    different options do not use different non-default values.

    >>> ds = tf.data.Dataset.range(5)
    >>> ds = ds.interleave(lambda x: tf.data.Dataset.range(5),
    ...                    cycle_length=3,
    ...                    num_parallel_calls=3)
    >>> options = tf.data.Options()
    >>> # This will make the interleave order non-deterministic.
    >>> options.experimental_deterministic = False
    >>> ds = ds.with_options(options)

    Args:
      options: A `tf.data.Options` that identifies the options to use.

    Returns:
      Dataset: A `Dataset` with the given options.

    Raises:
      ValueError: when an option is set more than once to a non-default value.
    """
    return _OptionsDataset(self, options)

  def cardinality(self):
    """Returns the cardinality of the dataset, if known.

    `cardinality` may return `tf.data.INFINITE_CARDINALITY` if the dataset
    contains an infinite number of elements or `tf.data.UNKNOWN_CARDINALITY` if
    the analysis fails to determine the number of elements in the dataset
    (e.g. when the dataset source is a file).
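    For example, the cardinality below is known statically for `range(42)`,
    becomes infinite after an unbounded `repeat()`, and becomes unknown after a
    `filter()`: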
2279 2280 >>> dataset = tf.data.Dataset.range(42) 2281 >>> print(dataset.cardinality().numpy()) 2282 42 2283 >>> dataset = dataset.repeat() 2284 >>> cardinality = dataset.cardinality() 2285 >>> print((cardinality == tf.data.INFINITE_CARDINALITY).numpy()) 2286 True 2287 >>> dataset = dataset.filter(lambda x: True) 2288 >>> cardinality = dataset.cardinality() 2289 >>> print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy()) 2290 True 2291 2292 Returns: 2293 A scalar `tf.int64` `Tensor` representing the cardinality of the dataset. 2294 If the cardinality is infinite or unknown, `cardinality` returns the 2295 named constants `tf.data.INFINITE_CARDINALITY` and 2296 `tf.data.UNKNOWN_CARDINALITY` respectively. 2297 """ 2298 return gen_dataset_ops.dataset_cardinality(self._variant_tensor) 2299 2300 2301@tf_export(v1=["data.Dataset"]) 2302class DatasetV1(DatasetV2): 2303 """Represents a potentially large set of elements. 2304 2305 A `Dataset` can be used to represent an input pipeline as a 2306 collection of elements and a "logical plan" of transformations that act on 2307 those elements. 2308 """ 2309 2310 def __init__(self): 2311 try: 2312 variant_tensor = self._as_variant_tensor() 2313 except AttributeError as e: 2314 if "_as_variant_tensor" in str(e): 2315 raise AttributeError("Please use _variant_tensor instead of " 2316 "_as_variant_tensor() to obtain the variant " 2317 "associated with a dataset") 2318 raise AttributeError("{}: A likely cause of this error is that the super " 2319 "call for this dataset is not the last line of the " 2320 "__init__ method. The base class causes the " 2321 "_as_variant_tensor call in its constructor and " 2322 "if that uses attributes defined in the __init__ " 2323 "method, those attrs need to be defined before the " 2324 "super call.".format(e)) 2325 super(DatasetV1, self).__init__(variant_tensor) 2326 2327 @abc.abstractmethod 2328 def _as_variant_tensor(self): 2329 """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset. 2330 2331 Returns: 2332 A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset. 2333 """ 2334 raise NotImplementedError("Dataset._as_variant_tensor") 2335 2336 @deprecation.deprecated( 2337 None, "This is a deprecated API that should only be used in TF 1 graph " 2338 "mode and legacy TF 2 graph mode available through `tf.compat.v1`. In " 2339 "all other situations -- namely, eager mode and inside `tf.function` -- " 2340 "you can consume dataset elements using `for elem in dataset: ...` or " 2341 "by explicitly creating iterator via `iterator = iter(dataset)` and " 2342 "fetching its elements via `values = next(iterator)`. Furthermore, " 2343 "this API is not available in TF 2. During the transition from TF 1 " 2344 "to TF 2 you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)` " 2345 "to create a TF 1 graph mode style iterator for a dataset created " 2346 "through TF 2 APIs. Note that this should be a transient state of your " 2347 "code base as there are in general no guarantees about the " 2348 "interoperability of TF 1 and TF 2 code.") 2349 def make_one_shot_iterator(self): 2350 """Creates an iterator for elements of this dataset. 2351 2352 Note: The returned iterator will be initialized automatically. 2353 A "one-shot" iterator does not currently support re-initialization. For 2354 that see `make_initializable_iterator`. 2355 2356 Example: 2357 2358 ```python 2359 # Building graph ... 2360 dataset = ... 2361 next_value = dataset.make_one_shot_iterator().get_next() 2362 2363 # ... 
from within a session ... 2364 try: 2365 while True: 2366 value = sess.run(next_value) 2367 ... 2368 except tf.errors.OutOfRangeError: 2369 pass 2370 ``` 2371 2372 Returns: 2373 An `tf.data.Iterator` for elements of this dataset. 2374 """ 2375 return self._make_one_shot_iterator() 2376 2377 def _make_one_shot_iterator(self): # pylint: disable=missing-docstring 2378 if context.executing_eagerly(): 2379 with ops.colocate_with(self._variant_tensor): 2380 return iterator_ops.OwnedIterator(self) 2381 2382 _ensure_same_dataset_graph(self) 2383 # Now that we create datasets at python object creation time, the capture 2384 # by value _make_dataset() function would try to capture these variant 2385 # tensor dataset inputs, which are marked as stateful ops and would throw 2386 # an error if we try and capture them. We therefore traverse the graph 2387 # to find all these ops and allowlist them so that the capturing 2388 # logic instead of throwing an error recreates these ops which is what was 2389 # happening before. 2390 all_ds_ops = traverse.obtain_all_variant_tensor_ops(self) 2391 graph_level_seed, op_level_seed = core_random_seed.get_seed(None) 2392 2393 # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is 2394 # a 0-argument function. 2395 @function.Defun(capture_by_value=True, allowlisted_stateful_ops=all_ds_ops) 2396 def _make_dataset(): 2397 """Factory function for a dataset.""" 2398 # NOTE(mrry): `Defun` does not capture the graph-level seed from the 2399 # enclosing graph, so if a graph-level seed is present we set the local 2400 # graph seed based on a combination of the graph- and op-level seeds. 2401 if graph_level_seed is not None: 2402 assert op_level_seed is not None 2403 core_random_seed.set_random_seed( 2404 (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1)) 2405 2406 dataset = self._apply_options() 2407 return dataset._variant_tensor # pylint: disable=protected-access 2408 2409 try: 2410 _make_dataset.add_to_graph(ops.get_default_graph()) 2411 except ValueError as err: 2412 if "Cannot capture a stateful node" in str(err): 2413 raise ValueError( 2414 "Failed to create a one-shot iterator for a dataset. " 2415 "`Dataset.make_one_shot_iterator()` does not support datasets that " 2416 "capture stateful objects, such as a `Variable` or `LookupTable`. " 2417 "In these cases, use `Dataset.make_initializable_iterator()`. " 2418 "(Original error: %s)" % err) 2419 else: 2420 six.reraise(ValueError, err) 2421 2422 with ops.colocate_with(self._variant_tensor): 2423 # pylint: disable=protected-access 2424 return iterator_ops.Iterator( 2425 gen_dataset_ops.one_shot_iterator( 2426 dataset_factory=_make_dataset, **self._flat_structure), None, 2427 get_legacy_output_types(self), get_legacy_output_shapes(self), 2428 get_legacy_output_classes(self)) 2429 2430 @deprecation.deprecated( 2431 None, "This is a deprecated API that should only be used in TF 1 graph " 2432 "mode and legacy TF 2 graph mode available through `tf.compat.v1`. " 2433 "In all other situations -- namely, eager mode and inside `tf.function` " 2434 "-- you can consume dataset elements using `for elem in dataset: ...` " 2435 "or by explicitly creating iterator via `iterator = iter(dataset)` " 2436 "and fetching its elements via `values = next(iterator)`. " 2437 "Furthermore, this API is not available in TF 2. 
During the transition " 2438 "from TF 1 to TF 2 you can use " 2439 "`tf.compat.v1.data.make_initializable_iterator(dataset)` to create a TF " 2440 "1 graph mode style iterator for a dataset created through TF 2 APIs. " 2441 "Note that this should be a transient state of your code base as there " 2442 "are in general no guarantees about the interoperability of TF 1 and TF " 2443 "2 code.") 2444 def make_initializable_iterator(self, shared_name=None): 2445 """Creates an iterator for elements of this dataset. 2446 2447 Note: The returned iterator will be in an uninitialized state, 2448 and you must run the `iterator.initializer` operation before using it: 2449 2450 ```python 2451 # Building graph ... 2452 dataset = ... 2453 iterator = dataset.make_initializable_iterator() 2454 next_value = iterator.get_next() # This is a Tensor. 2455 2456 # ... from within a session ... 2457 sess.run(iterator.initializer) 2458 try: 2459 while True: 2460 value = sess.run(next_value) 2461 ... 2462 except tf.errors.OutOfRangeError: 2463 pass 2464 ``` 2465 2466 Args: 2467 shared_name: (Optional.) If non-empty, the returned iterator will be 2468 shared under the given name across multiple sessions that share the same 2469 devices (e.g. when using a remote server). 2470 2471 Returns: 2472 A `tf.data.Iterator` for elements of this dataset. 2473 2474 Raises: 2475 RuntimeError: If eager execution is enabled. 2476 """ 2477 return self._make_initializable_iterator(shared_name) 2478 2479 def _make_initializable_iterator(self, shared_name=None): # pylint: disable=missing-docstring 2480 if context.executing_eagerly(): 2481 raise RuntimeError( 2482 "dataset.make_initializable_iterator is not supported when eager " 2483 "execution is enabled. Use `for element in dataset` instead.") 2484 _ensure_same_dataset_graph(self) 2485 dataset = self._apply_options() 2486 if shared_name is None: 2487 shared_name = "" 2488 2489 with ops.colocate_with(self._variant_tensor): 2490 iterator_resource = gen_dataset_ops.iterator_v2( 2491 container="", shared_name=shared_name, **self._flat_structure) 2492 2493 initializer = gen_dataset_ops.make_iterator( 2494 dataset._variant_tensor, # pylint: disable=protected-access 2495 iterator_resource) 2496 2497 # pylint: disable=protected-access 2498 return iterator_ops.Iterator(iterator_resource, initializer, 2499 get_legacy_output_types(dataset), 2500 get_legacy_output_shapes(dataset), 2501 get_legacy_output_classes(dataset)) 2502 2503 @property 2504 @deprecation.deprecated( 2505 None, "Use `tf.compat.v1.data.get_output_classes(dataset)`.") 2506 def output_classes(self): 2507 """Returns the class of each component of an element of this dataset. 2508 2509 Returns: 2510 A nested structure of Python `type` objects corresponding to each 2511 component of an element of this dataset. 2512 """ 2513 return nest.map_structure( 2514 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 2515 self.element_spec) 2516 2517 @property 2518 @deprecation.deprecated( 2519 None, "Use `tf.compat.v1.data.get_output_shapes(dataset)`.") 2520 def output_shapes(self): 2521 """Returns the shape of each component of an element of this dataset. 2522 2523 Returns: 2524 A nested structure of `tf.TensorShape` objects corresponding to each 2525 component of an element of this dataset. 
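
    For example (a sketch using the TF 1 `tf.compat.v1` API; the dataset below
    is illustrative):

    ```python
    dataset = tf.compat.v1.data.Dataset.from_tensor_slices(
        (tf.zeros([4, 2]), tf.zeros([4])))
    # Each element is a pair of a vector component and a scalar component, so
    # `output_shapes` is a tuple of two `tf.TensorShape`s: [2] and [].
    print(dataset.output_shapes)
    ```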
2526 """ 2527 return nest.map_structure( 2528 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 2529 self.element_spec) 2530 2531 @property 2532 @deprecation.deprecated( 2533 None, "Use `tf.compat.v1.data.get_output_types(dataset)`.") 2534 def output_types(self): 2535 """Returns the type of each component of an element of this dataset. 2536 2537 Returns: 2538 A nested structure of `tf.DType` objects corresponding to each component 2539 of an element of this dataset. 2540 """ 2541 return nest.map_structure( 2542 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 2543 self.element_spec) 2544 2545 @property 2546 def element_spec(self): 2547 # TODO(b/110122868): Remove this override once all `Dataset` instances 2548 # implement `element_structure`. 2549 return structure.convert_legacy_structure( 2550 self.output_types, self.output_shapes, self.output_classes) 2551 2552 @staticmethod 2553 @functools.wraps(DatasetV2.from_tensors) 2554 def from_tensors(tensors): 2555 return DatasetV1Adapter(DatasetV2.from_tensors(tensors)) 2556 2557 @staticmethod 2558 @functools.wraps(DatasetV2.from_tensor_slices) 2559 def from_tensor_slices(tensors): 2560 return DatasetV1Adapter(DatasetV2.from_tensor_slices(tensors)) 2561 2562 @staticmethod 2563 @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") 2564 def from_sparse_tensor_slices(sparse_tensor): 2565 """Splits each rank-N `tf.sparse.SparseTensor` in this dataset row-wise. 2566 2567 Args: 2568 sparse_tensor: A `tf.sparse.SparseTensor`. 2569 2570 Returns: 2571 Dataset: A `Dataset` of rank-(N-1) sparse tensors. 2572 """ 2573 return DatasetV1Adapter(SparseTensorSliceDataset(sparse_tensor)) 2574 2575 @staticmethod 2576 @functools.wraps(DatasetV2.from_generator) 2577 def from_generator(generator, 2578 output_types=None, 2579 output_shapes=None, 2580 args=None, 2581 output_signature=None): 2582 return DatasetV1Adapter( 2583 DatasetV2.from_generator(generator, output_types, output_shapes, args, 2584 output_signature)) 2585 2586 @staticmethod 2587 @functools.wraps(DatasetV2.range) 2588 def range(*args, **kwargs): 2589 return DatasetV1Adapter(DatasetV2.range(*args, **kwargs)) 2590 2591 @staticmethod 2592 @functools.wraps(DatasetV2.zip) 2593 def zip(datasets): 2594 return DatasetV1Adapter(DatasetV2.zip(datasets)) 2595 2596 @functools.wraps(DatasetV2.concatenate) 2597 def concatenate(self, dataset): 2598 return DatasetV1Adapter(super(DatasetV1, self).concatenate(dataset)) 2599 2600 @functools.wraps(DatasetV2.prefetch) 2601 def prefetch(self, buffer_size): 2602 return DatasetV1Adapter(super(DatasetV1, self).prefetch(buffer_size)) 2603 2604 @staticmethod 2605 @functools.wraps(DatasetV2.list_files) 2606 def list_files(file_pattern, shuffle=None, seed=None): 2607 return DatasetV1Adapter(DatasetV2.list_files(file_pattern, shuffle, seed)) 2608 2609 @functools.wraps(DatasetV2.repeat) 2610 def repeat(self, count=None): 2611 return DatasetV1Adapter(super(DatasetV1, self).repeat(count)) 2612 2613 @functools.wraps(DatasetV2.shuffle) 2614 def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None): 2615 return DatasetV1Adapter(super(DatasetV1, self).shuffle( 2616 buffer_size, seed, reshuffle_each_iteration)) 2617 2618 @functools.wraps(DatasetV2.cache) 2619 def cache(self, filename=""): 2620 return DatasetV1Adapter(super(DatasetV1, self).cache(filename)) 2621 2622 @functools.wraps(DatasetV2.take) 2623 def take(self, count): 2624 return 
DatasetV1Adapter(super(DatasetV1, self).take(count))

  @functools.wraps(DatasetV2.skip)
  def skip(self, count):
    return DatasetV1Adapter(super(DatasetV1, self).skip(count))

  @functools.wraps(DatasetV2.shard)
  def shard(self, num_shards, index):
    return DatasetV1Adapter(super(DatasetV1, self).shard(num_shards, index))

  @functools.wraps(DatasetV2.batch)
  def batch(self, batch_size, drop_remainder=False, num_parallel_calls=None):
    return DatasetV1Adapter(
        super(DatasetV1, self).batch(batch_size, drop_remainder,
                                     num_parallel_calls))

  @functools.wraps(DatasetV2.padded_batch)
  def padded_batch(self,
                   batch_size,
                   padded_shapes=None,
                   padding_values=None,
                   drop_remainder=False):
    return DatasetV1Adapter(
        super(DatasetV1, self).padded_batch(batch_size, padded_shapes,
                                            padding_values, drop_remainder))

  @functools.wraps(DatasetV2.map)
  def map(self, map_func, num_parallel_calls=None, deterministic=None):
    if num_parallel_calls is None:
      return DatasetV1Adapter(
          MapDataset(self, map_func, preserve_cardinality=False))
    else:
      return DatasetV1Adapter(
          ParallelMapDataset(
              self,
              map_func,
              num_parallel_calls,
              deterministic,
              preserve_cardinality=False))

  @deprecation.deprecated(None, "Use `tf.data.Dataset.map()`.")
  def map_with_legacy_function(self,
                               map_func,
                               num_parallel_calls=None,
                               deterministic=None):
    """Maps `map_func` across the elements of this dataset.

    Note: This is an escape hatch for existing uses of `map` that do not work
    with V2 functions. New uses are strongly discouraged and existing uses
    should migrate to `map` as this method will be removed in V2.

    Args:
      map_func: A function mapping a nested structure of tensors (having shapes
        and types defined by `self.output_shapes` and `self.output_types`) to
        another nested structure of tensors.
      num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
        representing the number of elements to process asynchronously in
        parallel. If not specified, elements will be processed sequentially. If
        the value `tf.data.AUTOTUNE` is used, then the number of parallel calls
        is set dynamically based on available CPU.
      deterministic: (Optional.) When `num_parallel_calls` is specified, this
        boolean controls the order in which the transformation produces
        elements. If set to `False`, the transformation is allowed to yield
        elements out of order to trade determinism for performance. If not
        specified, the `tf.data.Options.experimental_deterministic` option
        (`True` by default) controls the behavior.

    Returns:
      Dataset: A `Dataset`.
2693 """ 2694 if num_parallel_calls is None: 2695 if deterministic is not None: 2696 warnings.warn("The `deterministic` argument has no effect unless the " 2697 "`num_parallel_calls` argument is specified.") 2698 return DatasetV1Adapter( 2699 MapDataset( 2700 self, 2701 map_func, 2702 preserve_cardinality=False, 2703 use_legacy_function=True)) 2704 else: 2705 return DatasetV1Adapter( 2706 ParallelMapDataset( 2707 self, 2708 map_func, 2709 num_parallel_calls, 2710 deterministic, 2711 preserve_cardinality=False, 2712 use_legacy_function=True)) 2713 2714 @functools.wraps(DatasetV2.flat_map) 2715 def flat_map(self, map_func): 2716 return DatasetV1Adapter(super(DatasetV1, self).flat_map(map_func)) 2717 2718 @functools.wraps(DatasetV2.interleave) 2719 def interleave(self, 2720 map_func, 2721 cycle_length=None, 2722 block_length=None, 2723 num_parallel_calls=None, 2724 deterministic=None): 2725 return DatasetV1Adapter( 2726 super(DatasetV1, self).interleave(map_func, cycle_length, block_length, 2727 num_parallel_calls, deterministic)) 2728 2729 @functools.wraps(DatasetV2.filter) 2730 def filter(self, predicate): 2731 return DatasetV1Adapter(super(DatasetV1, self).filter(predicate)) 2732 2733 @deprecation.deprecated(None, "Use `tf.data.Dataset.filter()") 2734 def filter_with_legacy_function(self, predicate): 2735 """Filters this dataset according to `predicate`. 2736 2737 Note: This is an escape hatch for existing uses of `filter` that do not work 2738 with V2 functions. New uses are strongly discouraged and existing uses 2739 should migrate to `filter` as this method will be removed in V2. 2740 2741 Args: 2742 predicate: A function mapping a nested structure of tensors (having shapes 2743 and types defined by `self.output_shapes` and `self.output_types`) to a 2744 scalar `tf.bool` tensor. 2745 2746 Returns: 2747 Dataset: The `Dataset` containing the elements of this dataset for which 2748 `predicate` is `True`. 
2749 """ 2750 return FilterDataset(self, predicate, use_legacy_function=True) 2751 2752 @functools.wraps(DatasetV2.apply) 2753 def apply(self, transformation_func): 2754 return DatasetV1Adapter(super(DatasetV1, self).apply(transformation_func)) 2755 2756 @functools.wraps(DatasetV2.window) 2757 def window(self, size, shift=None, stride=1, drop_remainder=False): 2758 return DatasetV1Adapter(super(DatasetV1, self).window( 2759 size, shift, stride, drop_remainder)) 2760 2761 @functools.wraps(DatasetV2.unbatch) 2762 def unbatch(self): 2763 return DatasetV1Adapter(super(DatasetV1, self).unbatch()) 2764 2765 @functools.wraps(DatasetV2.with_options) 2766 def with_options(self, options): 2767 return DatasetV1Adapter(super(DatasetV1, self).with_options(options)) 2768 2769 2770if tf2.enabled(): 2771 Dataset = DatasetV2 2772else: 2773 Dataset = DatasetV1 2774 2775 2776class DatasetV1Adapter(DatasetV1): 2777 """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API.""" 2778 2779 def __init__(self, dataset): 2780 self._dataset = dataset 2781 super(DatasetV1Adapter, self).__init__() 2782 2783 def _as_variant_tensor(self): 2784 return self._dataset._variant_tensor # pylint: disable=protected-access 2785 2786 def _has_captured_ref(self): 2787 return self._dataset._has_captured_ref() # pylint: disable=protected-access 2788 2789 def _inputs(self): 2790 return self._dataset._inputs() # pylint: disable=protected-access 2791 2792 def _functions(self): 2793 return self._dataset._functions() # pylint: disable=protected-access 2794 2795 def options(self): 2796 return self._dataset.options() 2797 2798 @property 2799 def element_spec(self): 2800 return self._dataset.element_spec # pylint: disable=protected-access 2801 2802 def __iter__(self): 2803 return iter(self._dataset) 2804 2805 2806def _ensure_same_dataset_graph(dataset): 2807 """Walks the dataset graph to ensure all datasets come from the same graph.""" 2808 # pylint: disable=protected-access 2809 current_graph = ops.get_default_graph() 2810 bfs_q = Queue.Queue() 2811 bfs_q.put(dataset) 2812 visited = [] 2813 while not bfs_q.empty(): 2814 ds = bfs_q.get() 2815 visited.append(ds) 2816 ds_graph = ds._graph 2817 if current_graph != ds_graph: 2818 raise ValueError( 2819 "The graph (" + str(current_graph) + ") of the iterator is different " 2820 "from the graph (" + str(ds_graph) + ") the dataset: " + 2821 str(ds._variant_tensor) + " was created in. If you are using the " 2822 "Estimator API, make sure that no part of the dataset returned by " 2823 "the `input_fn` function is defined outside the `input_fn` function. " 2824 "Please ensure that all datasets in the pipeline are created in the " 2825 "same graph as the iterator.") 2826 for input_ds in ds._inputs(): 2827 if input_ds not in visited: 2828 bfs_q.put(input_ds) 2829 2830 2831@tf_export(v1=["data.make_one_shot_iterator"]) 2832def make_one_shot_iterator(dataset): 2833 """Creates an iterator for elements of `dataset`. 2834 2835 Note: The returned iterator will be initialized automatically. 2836 A "one-shot" iterator does not support re-initialization. 2837 2838 Args: 2839 dataset: A `tf.data.Dataset`. 2840 2841 Returns: 2842 A `tf.data.Iterator` for elements of `dataset`. 2843 """ 2844 try: 2845 # Call the defined `_make_one_shot_iterator()` if there is one, because some 2846 # datasets (e.g. for prefetching) override its behavior. 
2847 return dataset._make_one_shot_iterator() # pylint: disable=protected-access 2848 except AttributeError: 2849 return DatasetV1Adapter(dataset)._make_one_shot_iterator() # pylint: disable=protected-access 2850 2851 2852@tf_export(v1=["data.make_initializable_iterator"]) 2853def make_initializable_iterator(dataset, shared_name=None): 2854 """Creates an iterator for elements of `dataset`. 2855 2856 Note: The returned iterator will be in an uninitialized state, 2857 and you must run the `iterator.initializer` operation before using it: 2858 2859 ```python 2860 dataset = ... 2861 iterator = tf.compat.v1.data.make_initializable_iterator(dataset) 2862 # ... 2863 sess.run(iterator.initializer) 2864 ``` 2865 2866 Args: 2867 dataset: A `tf.data.Dataset`. 2868 shared_name: (Optional.) If non-empty, the returned iterator will be shared 2869 under the given name across multiple sessions that share the same devices 2870 (e.g. when using a remote server). 2871 2872 Returns: 2873 A `tf.data.Iterator` for elements of `dataset`. 2874 2875 Raises: 2876 RuntimeError: If eager execution is enabled. 2877 """ 2878 try: 2879 # Call the defined `_make_initializable_iterator()` if there is one, because 2880 # some datasets (e.g. for prefetching) override its behavior. 2881 return dataset._make_initializable_iterator(shared_name) # pylint: disable=protected-access 2882 except AttributeError: 2883 return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name) # pylint: disable=protected-access 2884 2885 2886@tf_export("data.experimental.get_structure") 2887def get_structure(dataset_or_iterator): 2888 """Returns the type signature for elements of the input dataset / iterator. 2889 2890 Args: 2891 dataset_or_iterator: A `tf.data.Dataset` or an `tf.data.Iterator`. 2892 2893 Returns: 2894 A nested structure of `tf.TypeSpec` objects matching the structure of an 2895 element of `dataset_or_iterator` and specifying the type of individual 2896 components. 2897 2898 Raises: 2899 TypeError: If input is not a `tf.data.Dataset` or an `tf.data.Iterator` 2900 object. 2901 """ 2902 try: 2903 return dataset_or_iterator.element_spec # pylint: disable=protected-access 2904 except AttributeError: 2905 raise TypeError("`dataset_or_iterator` must be a `tf.data.Dataset` or " 2906 "tf.data.Iterator object, but got %s." % 2907 type(dataset_or_iterator)) 2908 2909 2910@tf_export(v1=["data.get_output_classes"]) 2911def get_legacy_output_classes(dataset_or_iterator): 2912 """Returns the output classes for elements of the input dataset / iterator. 2913 2914 Args: 2915 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. 2916 2917 Returns: 2918 A nested structure of Python `type` objects matching the structure of the 2919 dataset / iterator elements and specifying the class of the individual 2920 components. 2921 """ 2922 return nest.map_structure( 2923 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 2924 get_structure(dataset_or_iterator)) 2925 2926 2927@tf_export(v1=["data.get_output_shapes"]) 2928def get_legacy_output_shapes(dataset_or_iterator): 2929 """Returns the output shapes for elements of the input dataset / iterator. 2930 2931 Args: 2932 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. 2933 2934 Returns: 2935 A nested structure of `tf.TensorShape` objects matching the structure of 2936 the dataset / iterator elements and specifying the shape of the individual 2937 components. 
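
  For example (a sketch; the dataset below is illustrative):

  ```python
  dataset = tf.data.Dataset.from_tensor_slices(
      {"image": tf.zeros([8, 28, 28]), "label": tf.zeros([8], dtype=tf.int32)})
  shapes = tf.compat.v1.data.get_output_shapes(dataset)
  # `shapes` is a dict of `tf.TensorShape`s: {"image": [28, 28], "label": []}.
  ```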
2938 """ 2939 return nest.map_structure( 2940 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 2941 get_structure(dataset_or_iterator)) 2942 2943 2944@tf_export(v1=["data.get_output_types"]) 2945def get_legacy_output_types(dataset_or_iterator): 2946 """Returns the output shapes for elements of the input dataset / iterator. 2947 2948 Args: 2949 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. 2950 2951 Returns: 2952 A nested structure of `tf.DType` objects matching the structure of 2953 dataset / iterator elements and specifying the shape of the individual 2954 components. 2955 """ 2956 return nest.map_structure( 2957 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 2958 get_structure(dataset_or_iterator)) 2959 2960 2961@tf_export("data.Options") 2962class Options(options_lib.OptionsBase): 2963 """Represents options for `tf.data.Dataset`. 2964 2965 A `tf.data.Options` object can be, for instance, used to control which static 2966 optimizations to apply to the input pipeline graph or whether to use 2967 performance modeling to dynamically tune the parallelism of operations such as 2968 `tf.data.Dataset.map` or `tf.data.Dataset.interleave`. 2969 2970 The options are set for the entire dataset and are carried over to datasets 2971 created through tf.data transformations. 2972 2973 The options can be set either by mutating the object returned by 2974 `tf.data.Dataset.options()` or by constructing an `Options` object and using 2975 the `tf.data.Dataset.with_options(options)` transformation, which returns a 2976 dataset with the options set. 2977 2978 >>> dataset = tf.data.Dataset.range(42) 2979 >>> dataset.options().experimental_deterministic = False 2980 >>> print(dataset.options().experimental_deterministic) 2981 False 2982 2983 >>> dataset = tf.data.Dataset.range(42) 2984 >>> options = tf.data.Options() 2985 >>> options.experimental_deterministic = False 2986 >>> dataset = dataset.with_options(options) 2987 >>> print(dataset.options().experimental_deterministic) 2988 False 2989 2990 Note: A known limitation of the `tf.data.Options` implementation is that the 2991 options are not preserved across tf.function boundaries. In particular, to 2992 set options for a dataset that is iterated within a tf.function, the options 2993 need to be set within the same tf.function. 2994 """ 2995 2996 experimental_deterministic = options_lib.create_option( 2997 name="experimental_deterministic", 2998 ty=bool, 2999 docstring= 3000 "Whether the outputs need to be produced in deterministic order. If None," 3001 " defaults to True.") 3002 3003 experimental_distribute = options_lib.create_option( 3004 name="experimental_distribute", 3005 ty=distribute_options.DistributeOptions, 3006 docstring= 3007 "The distribution strategy options associated with the dataset. See " 3008 "`tf.data.experimental.DistributeOptions` for more details.", 3009 default_factory=distribute_options.DistributeOptions) 3010 3011 experimental_optimization = options_lib.create_option( 3012 name="experimental_optimization", 3013 ty=optimization_options.OptimizationOptions, 3014 docstring= 3015 "The optimization options associated with the dataset. 
See " 3016 "`tf.data.experimental.OptimizationOptions` for more details.", 3017 default_factory=optimization_options.OptimizationOptions) 3018 3019 experimental_slack = options_lib.create_option( 3020 name="experimental_slack", 3021 ty=bool, 3022 docstring="Whether to introduce 'slack' in the last `prefetch` of the " 3023 "input pipeline, if it exists. This may reduce CPU contention with " 3024 "accelerator host-side activity at the start of a step. The slack " 3025 "frequency is determined by the number of devices attached to this " 3026 "input pipeline. If None, defaults to False.") 3027 3028 experimental_stats = options_lib.create_option( 3029 name="experimental_stats", 3030 ty=stats_options.StatsOptions, 3031 docstring= 3032 "The statistics options associated with the dataset. See " 3033 "`tf.data.experimental.StatsOptions` for more details.", 3034 default_factory=stats_options.StatsOptions) 3035 3036 experimental_threading = options_lib.create_option( 3037 name="experimental_threading", 3038 ty=threading_options.ThreadingOptions, 3039 docstring= 3040 "The threading options associated with the dataset. See " 3041 "`tf.data.experimental.ThreadingOptions` for more details.", 3042 default_factory=threading_options.ThreadingOptions) 3043 3044 experimental_external_state_policy = options_lib.create_option( 3045 name="experimental_external_state_policy", 3046 ty=distribute_options.ExternalStatePolicy, 3047 docstring="This option can be used to override the default policy for " 3048 "how to handle external state when serializing a dataset or " 3049 "checkpointing its iterator. There are three settings available - " 3050 "IGNORE: External state is ignored without a warning; WARN: External " 3051 "state is ignored and a warning is logged; FAIL: External state results " 3052 "in an error.") 3053 3054 def _to_proto(self): 3055 pb = dataset_options_pb2.Options() 3056 if self.experimental_deterministic is not None: 3057 pb.deterministic = self.experimental_deterministic 3058 pb.distribute_options.CopyFrom(self.experimental_distribute._to_proto()) # pylint: disable=protected-access 3059 if self.experimental_external_state_policy is not None: 3060 pb.external_state_policy = ( 3061 distribute_options.ExternalStatePolicy._to_proto( # pylint: disable=protected-access 3062 self.experimental_external_state_policy)) 3063 pb.optimization_options.CopyFrom(self.experimental_optimization._to_proto()) # pylint: disable=protected-access 3064 if self.experimental_slack is not None: 3065 pb.slack = self.experimental_slack 3066 pb.threading_options.CopyFrom(self.experimental_threading._to_proto()) # pylint: disable=protected-access 3067 return pb 3068 3069 def _from_proto(self, pb): 3070 if pb.WhichOneof("optional_deterministic") is not None: 3071 self.experimental_deterministic = pb.deterministic 3072 self.experimental_distribute._from_proto(pb.distribute_options) # pylint: disable=protected-access 3073 if pb.WhichOneof("optional_external_state_policy") is not None: 3074 self.experimental_external_state_policy = ( 3075 distribute_options.ExternalStatePolicy._from_proto( # pylint: disable=protected-access 3076 pb.external_state_policy)) 3077 self.experimental_optimization._from_proto(pb.optimization_options) # pylint: disable=protected-access 3078 if pb.WhichOneof("optional_slack") is not None: 3079 self.experimental_slack = pb.slack 3080 self.experimental_threading._from_proto(pb.threading_options) # pylint: disable=protected-access 3081 3082 def _graph_rewrites(self): 3083 """Produces lists of enabled, disabled, 
default static graph rewrites. 3084 3085 Returns: 3086 result: a namedtuple with three attributes. `result.enabled` is the list 3087 of user enabled graph rewrites. `result.disabled` is the list of user 3088 disabled graph rewrites. `result.default` is the list of graph 3089 rewrites that are enabled by default (the user has not explicitly 3090 enabled or disabled them). 3091 """ 3092 if self.experimental_optimization is not None: 3093 result = self.experimental_optimization._graph_rewrites() # pylint: disable=protected-access 3094 else: 3095 # Apply default options 3096 result = optimization_options.OptimizationOptions()._graph_rewrites() # pylint: disable=protected-access 3097 3098 if self.experimental_deterministic is False: # pylint: disable=g-bool-id-comparison 3099 result.enabled.append("make_sloppy") 3100 elif self.experimental_deterministic is True: # pylint: disable=g-bool-id-comparison 3101 result.disabled.append("make_sloppy") 3102 if self.experimental_stats: 3103 if self.experimental_stats.latency_all_edges is True: # pylint: disable=g-bool-id-comparison 3104 result.enabled.append("latency_all_edges") 3105 elif self.experimental_stats.latency_all_edges is False: # pylint: disable=g-bool-id-comparison 3106 result.disabled.append("latency_all_edges") 3107 if self.experimental_slack is True: # pylint: disable=g-bool-id-comparison 3108 result.enabled.append("slack") 3109 elif self.experimental_slack is False: # pylint: disable=g-bool-id-comparison 3110 result.disabled.append("slack") 3111 3112 graph_rewrites = options_lib.graph_rewrites() 3113 return graph_rewrites(enabled=list(set(result.enabled)), 3114 disabled=list(set(result.disabled)), 3115 default=list(set(result.default))) 3116 3117 def _graph_rewrite_configs(self, autotune): 3118 """Produces the list of configurations for enabled graph optimizations.""" 3119 result = [] 3120 if self.experimental_optimization: 3121 result.extend( 3122 self.experimental_optimization._graph_rewrite_configs(autotune)) # pylint: disable=protected-access 3123 3124 if self.experimental_slack: 3125 num_devices = self.experimental_distribute.num_devices 3126 if num_devices is None: 3127 num_devices = 1 3128 result.append("slack:slack_period:%d" % num_devices) 3129 return result 3130 3131 def _autotune_settings(self): 3132 if self.experimental_optimization is not None: 3133 return self.experimental_optimization._autotune_settings() # pylint: disable=protected-access 3134 3135 # Return default autotune options 3136 return optimization_options.OptimizationOptions()._autotune_settings() # pylint: disable=protected-access 3137 3138 def merge(self, options): 3139 """Merges itself with the given `tf.data.Options`. 3140 3141 If this object and the `options` to merge set an option differently, a 3142 warning is generated and this object's value is updated with the `options` 3143 object's value. 3144 3145 Args: 3146 options: a `tf.data.Options` to merge with 3147 3148 Returns: 3149 New `tf.data.Options` object which is the result of merging self with 3150 the input `tf.data.Options`. 
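
    For example (illustrative example, not part of the original docstring):

    >>> options1 = tf.data.Options()
    >>> options1.experimental_deterministic = False
    >>> options2 = tf.data.Options()
    >>> options2.experimental_slack = True
    >>> merged = options1.merge(options2)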
3151 """ 3152 return options_lib.merge_options(self, options) 3153 3154 3155class DatasetSource(DatasetV2): 3156 """Abstract class representing a dataset with no inputs.""" 3157 3158 def _inputs(self): 3159 return [] 3160 3161 3162class UnaryDataset(DatasetV2): 3163 """Abstract class representing a dataset with one input.""" 3164 3165 def __init__(self, input_dataset, variant_tensor): 3166 self._input_dataset = input_dataset 3167 super(UnaryDataset, self).__init__(variant_tensor) 3168 3169 def _inputs(self): 3170 return [self._input_dataset] 3171 3172 3173class UnaryUnchangedStructureDataset(UnaryDataset): 3174 """Represents a unary dataset with the same input and output structure.""" 3175 3176 def __init__(self, input_dataset, variant_tensor): 3177 self._input_dataset = input_dataset 3178 super(UnaryUnchangedStructureDataset, self).__init__( 3179 input_dataset, variant_tensor) 3180 3181 @property 3182 def element_spec(self): 3183 return self._input_dataset.element_spec 3184 3185 3186class TensorDataset(DatasetSource): 3187 """A `Dataset` with a single element.""" 3188 3189 def __init__(self, element): 3190 """See `Dataset.from_tensors()` for details.""" 3191 element = structure.normalize_element(element) 3192 self._structure = structure.type_spec_from_value(element) 3193 self._tensors = structure.to_tensor_list(self._structure, element) 3194 3195 variant_tensor = gen_dataset_ops.tensor_dataset( 3196 self._tensors, 3197 output_shapes=structure.get_flat_tensor_shapes(self._structure)) 3198 super(TensorDataset, self).__init__(variant_tensor) 3199 3200 @property 3201 def element_spec(self): 3202 return self._structure 3203 3204 3205class TensorSliceDataset(DatasetSource): 3206 """A `Dataset` of slices from a dataset element.""" 3207 3208 def __init__(self, element): 3209 """See `Dataset.from_tensor_slices()` for details.""" 3210 element = structure.normalize_element(element) 3211 batched_spec = structure.type_spec_from_value(element) 3212 self._tensors = structure.to_batched_tensor_list(batched_spec, element) 3213 self._structure = nest.map_structure( 3214 lambda component_spec: component_spec._unbatch(), batched_spec) # pylint: disable=protected-access 3215 3216 batch_dim = tensor_shape.Dimension(tensor_shape.dimension_value( 3217 self._tensors[0].get_shape()[0])) 3218 for t in self._tensors[1:]: 3219 batch_dim.assert_is_compatible_with(tensor_shape.Dimension( 3220 tensor_shape.dimension_value(t.get_shape()[0]))) 3221 3222 variant_tensor = gen_dataset_ops.tensor_slice_dataset( 3223 self._tensors, 3224 output_shapes=structure.get_flat_tensor_shapes(self._structure)) 3225 super(TensorSliceDataset, self).__init__(variant_tensor) 3226 3227 @property 3228 def element_spec(self): 3229 return self._structure 3230 3231 3232class SparseTensorSliceDataset(DatasetSource): 3233 """A `Dataset` that splits a rank-N `tf.sparse.SparseTensor` into its rows.""" 3234 3235 def __init__(self, sparse_tensor): 3236 """See `Dataset.from_sparse_tensor_slices()` for details.""" 3237 if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor): 3238 raise TypeError( 3239 "`sparse_tensor` must be a `tf.sparse.SparseTensor` object." 
3240 "Was {}.".format(sparse_tensor)) 3241 self._sparse_tensor = sparse_tensor 3242 3243 indices_shape = self._sparse_tensor.indices.get_shape() 3244 shape_shape = self._sparse_tensor.dense_shape.get_shape() 3245 rank = (indices_shape.dims[1] - 1).merge_with(shape_shape.dims[0] - 1) 3246 self._structure = (tensor_spec.TensorSpec([None, rank], dtypes.int64), 3247 tensor_spec.TensorSpec([None], 3248 self._sparse_tensor.dtype), 3249 tensor_spec.TensorSpec([rank], dtypes.int64)) 3250 3251 variant_tensor = gen_dataset_ops.sparse_tensor_slice_dataset( 3252 self._sparse_tensor.indices, self._sparse_tensor.values, 3253 self._sparse_tensor.dense_shape) 3254 super(SparseTensorSliceDataset, self).__init__(variant_tensor) 3255 3256 @property 3257 def element_spec(self): 3258 return self._structure 3259 3260 3261class _VariantDataset(DatasetV2): 3262 """A Dataset wrapper around a `tf.variant`-typed function argument.""" 3263 3264 def __init__(self, dataset_variant, structure): 3265 self._structure = structure 3266 super(_VariantDataset, self).__init__(dataset_variant) 3267 3268 def _inputs(self): 3269 return [] 3270 3271 @property 3272 def element_spec(self): 3273 return self._structure 3274 3275 3276class _NestedVariant(composite_tensor.CompositeTensor): 3277 3278 def __init__(self, variant_tensor, element_spec, dataset_shape): 3279 self._variant_tensor = variant_tensor 3280 self._element_spec = element_spec 3281 self._dataset_shape = dataset_shape 3282 3283 @property 3284 def _type_spec(self): 3285 return DatasetSpec(self._element_spec, self._dataset_shape) 3286 3287 3288@tf_export("data.experimental.from_variant") 3289def from_variant(variant, structure): 3290 """Constructs a dataset from the given variant and structure. 3291 3292 Args: 3293 variant: A scalar `tf.variant` tensor representing a dataset. 3294 structure: A `tf.data.experimental.Structure` object representing the 3295 structure of each element in the dataset. 3296 3297 Returns: 3298 A `tf.data.Dataset` instance. 3299 """ 3300 return _VariantDataset(variant, structure) # pylint: disable=protected-access 3301 3302 3303@tf_export("data.experimental.to_variant") 3304def to_variant(dataset): 3305 """Returns a variant representing the given dataset. 3306 3307 Args: 3308 dataset: A `tf.data.Dataset`. 3309 3310 Returns: 3311 A scalar `tf.variant` tensor representing the given dataset. 3312 """ 3313 return dataset._variant_tensor # pylint: disable=protected-access 3314 3315 3316@tf_export( 3317 "data.DatasetSpec", 3318 v1=["data.DatasetSpec", "data.experimental.DatasetStructure"]) 3319class DatasetSpec(type_spec.BatchableTypeSpec): 3320 """Type specification for `tf.data.Dataset`. 3321 3322 See `tf.TypeSpec` for more information about TensorFlow type specifications. 
3323 3324 >>> dataset = tf.data.Dataset.range(3) 3325 >>> tf.data.DatasetSpec.from_value(dataset) 3326 DatasetSpec(TensorSpec(shape=(), dtype=tf.int64, name=None), TensorShape([])) 3327 """ 3328 3329 __slots__ = ["_element_spec", "_dataset_shape"] 3330 3331 def __init__(self, element_spec, dataset_shape=()): 3332 self._element_spec = element_spec 3333 self._dataset_shape = tensor_shape.as_shape(dataset_shape) 3334 3335 @property 3336 def value_type(self): 3337 return Dataset 3338 3339 @property 3340 def element_spec(self): 3341 """The inner element spec.""" 3342 return self._element_spec 3343 3344 def _serialize(self): 3345 return (self._element_spec, self._dataset_shape) 3346 3347 @property 3348 def _component_specs(self): 3349 return tensor_spec.TensorSpec(self._dataset_shape, dtypes.variant) 3350 3351 def _to_components(self, value): 3352 return value._variant_tensor # pylint: disable=protected-access 3353 3354 def _from_components(self, components): 3355 # pylint: disable=protected-access 3356 if self._dataset_shape.ndims == 0: 3357 return _VariantDataset(components, self._element_spec) 3358 else: 3359 return _NestedVariant(components, self._element_spec, self._dataset_shape) 3360 3361 def _to_tensor_list(self, value): 3362 return [ 3363 ops.convert_to_tensor( 3364 tf_nest.map_structure(lambda x: x._variant_tensor, value)) # pylint: disable=protected-access 3365 ] 3366 3367 @staticmethod 3368 def from_value(value): 3369 """Creates a `DatasetSpec` for the given `tf.data.Dataset` value.""" 3370 return DatasetSpec(value.element_spec) # pylint: disable=protected-access 3371 3372 def _batch(self, batch_size): 3373 return DatasetSpec( 3374 self._element_spec, 3375 tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape)) 3376 3377 def _unbatch(self): 3378 if self._dataset_shape.ndims == 0: 3379 raise ValueError("Unbatching a dataset is only supported for rank >= 1") 3380 return DatasetSpec(self._element_spec, self._dataset_shape[1:]) 3381 3382 def _to_batched_tensor_list(self, value): 3383 if self._dataset_shape.ndims == 0: 3384 raise ValueError("Unbatching a dataset is only supported for rank >= 1") 3385 return self._to_tensor_list(value) 3386 3387 def _to_legacy_output_types(self): 3388 return self 3389 3390 def _to_legacy_output_shapes(self): 3391 return self 3392 3393 def _to_legacy_output_classes(self): 3394 return self 3395 3396 3397class StructuredFunctionWrapper(object): 3398 """A function wrapper that supports structured arguments and return values.""" 3399 3400 def __init__(self, 3401 func, 3402 transformation_name, 3403 dataset=None, 3404 input_classes=None, 3405 input_shapes=None, 3406 input_types=None, 3407 input_structure=None, 3408 add_to_graph=True, 3409 use_legacy_function=False, 3410 defun_kwargs=None): 3411 """Creates a new `StructuredFunctionWrapper` for the given function. 3412 3413 Args: 3414 func: A function from a nested structure to another nested structure. 3415 transformation_name: Human-readable name of the transformation in which 3416 this function is being instantiated, for error messages. 3417 dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this 3418 dataset will be assumed as the structure for `func` arguments; otherwise 3419 `input_classes`, `input_shapes`, and `input_types` must be defined. 3420 input_classes: (Optional.) A nested structure of `type`. If given, this 3421 argument defines the Python types for `func` arguments. 3422 input_shapes: (Optional.) A nested structure of `tf.TensorShape`. 
If 3423 given, this argument defines the shapes and structure for `func` 3424 arguments. 3425 input_types: (Optional.) A nested structure of `tf.DType`. If given, this 3426 argument defines the element types and structure for `func` arguments. 3427 input_structure: (Optional.) A `Structure` object. If given, this argument 3428 defines the element types and structure for `func` arguments. 3429 add_to_graph: (Optional.) If `True`, the function will be added to the 3430 default graph, if it exists. 3431 use_legacy_function: (Optional.) A boolean that determines whether the 3432 function be created using `tensorflow.python.eager.function.defun` 3433 (default behavior) or `tensorflow.python.framework.function.Defun` 3434 (legacy behavior). 3435 defun_kwargs: (Optional.) A dictionary mapping string argument names to 3436 values. If supplied, will be passed to `function` as keyword arguments. 3437 3438 Raises: 3439 ValueError: If an invalid combination of `dataset`, `input_classes`, 3440 `input_shapes`, and `input_types` is passed. 3441 """ 3442 # pylint: disable=protected-access 3443 if input_structure is None: 3444 if dataset is None: 3445 if input_classes is None or input_shapes is None or input_types is None: 3446 raise ValueError("Either `dataset`, `input_structure` or all of " 3447 "`input_classes`, `input_shapes`, and `input_types` " 3448 "must be specified.") 3449 self._input_structure = structure.convert_legacy_structure( 3450 input_types, input_shapes, input_classes) 3451 else: 3452 if not (input_classes is None and input_shapes is None and 3453 input_types is None): 3454 raise ValueError("Either `dataset`, `input_structure` or all of " 3455 "`input_classes`, `input_shapes`, and `input_types` " 3456 "must be specified.") 3457 self._input_structure = dataset.element_spec 3458 else: 3459 if not (dataset is None and input_classes is None and input_shapes is None 3460 and input_types is None): 3461 raise ValueError("Either `dataset`, `input_structure`, or all of " 3462 "`input_classes`, `input_shapes`, and `input_types` " 3463 "must be specified.") 3464 self._input_structure = input_structure 3465 3466 self._func = func 3467 3468 # There is no graph to add in eager mode. 3469 add_to_graph &= not context.executing_eagerly() 3470 # There are some lifetime issues when a legacy function is not added to a 3471 # out-living graph. It's already deprecated so de-prioritizing the fix. 3472 add_to_graph |= use_legacy_function 3473 3474 if defun_kwargs is None: 3475 defun_kwargs = {} 3476 3477 readable_transformation_name = transformation_name.replace( 3478 ".", "_")[:-2] if len(transformation_name) > 2 else "" 3479 3480 func_name = "_".join( 3481 [readable_transformation_name, 3482 function_utils.get_func_name(func)]) 3483 # Sanitize function name to remove symbols that interfere with graph 3484 # construction. 3485 for symbol in ["<", ">", "\\", "'", " "]: 3486 func_name = func_name.replace(symbol, "") 3487 3488 ag_ctx = autograph_ctx.control_status_ctx() 3489 3490 def _warn_if_collections(transformation_name): 3491 """Prints a warning if the given graph uses common graph collections. 3492 3493 NOTE(mrry): Currently a warning is only generated for resources. Any 3494 variables created will be automatically hoisted out to the outermost scope 3495 using `init_scope()`. Some collections (such as for control-flow contexts) 3496 are benign and should not generate a warning. 3497 3498 Args: 3499 transformation_name: A human-readable name for the transformation. 
3500 """ 3501 warnings.warn("Creating resources inside a function passed to %s " 3502 "is not supported. Create each resource outside the " 3503 "function, and capture it inside the function to use it." % 3504 transformation_name, stacklevel=5) 3505 3506 def _wrapper_helper(*args): 3507 """Wrapper for passing nested structures to and from tf.data functions.""" 3508 nested_args = structure.from_compatible_tensor_list( 3509 self._input_structure, args) 3510 if not _should_unpack_args(nested_args): 3511 nested_args = (nested_args,) 3512 3513 ret = autograph.tf_convert(func, ag_ctx)(*nested_args) 3514 # If `func` returns a list of tensors, `nest.flatten()` and 3515 # `ops.convert_to_tensor()` would conspire to attempt to stack 3516 # those tensors into a single tensor, because the customized 3517 # version of `nest.flatten()` does not recurse into lists. Since 3518 # it is more likely that the list arose from returning the 3519 # result of an operation (such as `tf.numpy_function()`) that returns a 3520 # list of not-necessarily-stackable tensors, we treat the 3521 # returned value is a `tuple` instead. A user wishing to pack 3522 # the return value into a single tensor can use an explicit 3523 # `tf.stack()` before returning. 3524 if isinstance(ret, list): 3525 ret = tuple(ret) 3526 3527 try: 3528 self._output_structure = structure.type_spec_from_value(ret) 3529 except (ValueError, TypeError): 3530 six.reraise( 3531 TypeError, 3532 TypeError("Unsupported return value from function passed to " 3533 "%s: %s." % (transformation_name, ret)), 3534 sys.exc_info()[2]) 3535 return ret 3536 3537 if use_legacy_function: 3538 func_name = func_name + "_" + str(ops.uid()) 3539 3540 @function.Defun( 3541 *structure.get_flat_tensor_types(self._input_structure), 3542 func_name=func_name, 3543 **defun_kwargs) 3544 def wrapper_fn(*args): 3545 ret = _wrapper_helper(*args) 3546 # _warn_if_collections(transformation_name, ops.get_default_graph(), 0) 3547 return structure.to_tensor_list(self._output_structure, ret) 3548 3549 self._function = wrapper_fn 3550 resource_tracker = tracking.ResourceTracker() 3551 with tracking.resource_tracker_scope(resource_tracker): 3552 if add_to_graph: 3553 self._function.add_to_graph(ops.get_default_graph()) 3554 else: 3555 # Use the private method that will execute `wrapper_fn` but delay 3556 # adding it to the graph in case (e.g.) we need to rerun the function. 3557 self._function._create_definition_if_needed() 3558 if resource_tracker.resources: 3559 _warn_if_collections(transformation_name) 3560 3561 else: 3562 if def_function.functions_run_eagerly(): 3563 warnings.warn( 3564 "Even though the tf.config.experimental_run_functions_eagerly " 3565 "option is set, this option does not apply to tf.data functions. " 3566 "tf.data functions are still traced and executed as graphs.") 3567 3568 defun_kwargs.update({"func_name": func_name}) 3569 defun_kwargs.update({"_tf_data_function": True}) 3570 3571 # Note: _wrapper_helper will apply autograph based on context. 
3572 @eager_function.defun_with_attributes( 3573 input_signature=structure.get_flat_tensor_specs( 3574 self._input_structure), 3575 autograph=False, 3576 attributes=defun_kwargs) 3577 def wrapper_fn(*args): # pylint: disable=missing-docstring 3578 ret = _wrapper_helper(*args) 3579 ret = structure.to_tensor_list(self._output_structure, ret) 3580 return [ops.convert_to_tensor(t) for t in ret] 3581 3582 resource_tracker = tracking.ResourceTracker() 3583 with tracking.resource_tracker_scope(resource_tracker): 3584 # TODO(b/141462134): Switch to using garbage collection. 3585 self._function = wrapper_fn.get_concrete_function() 3586 if add_to_graph: 3587 self._function.add_to_graph(ops.get_default_graph()) 3588 3589 if resource_tracker.resources: 3590 _warn_if_collections(transformation_name) 3591 3592 outer_graph_seed = ops.get_default_graph().seed 3593 if outer_graph_seed and self._function.graph.seed == outer_graph_seed: 3594 if self._function.graph._seed_used: 3595 warnings.warn( 3596 "Seed %s from outer graph might be getting used by function %s, " 3597 "if the random op has not been provided any seed. Explicitly set " 3598 "the seed in the function if this is not the intended behavior." 3599 %(outer_graph_seed, func_name), stacklevel=4) 3600 3601 @property 3602 def output_structure(self): 3603 return self._output_structure 3604 3605 @property 3606 def output_classes(self): 3607 return nest.map_structure( 3608 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 3609 self._output_structure) 3610 3611 @property 3612 def output_shapes(self): 3613 return nest.map_structure( 3614 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 3615 self._output_structure) 3616 3617 @property 3618 def output_types(self): 3619 return nest.map_structure( 3620 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 3621 self._output_structure) 3622 3623 @property 3624 def function(self): 3625 return self._function 3626 3627 3628class _GeneratorDataset(DatasetSource): 3629 """A `Dataset` that generates elements by invoking a function.""" 3630 3631 def __init__(self, init_args, init_func, next_func, finalize_func, 3632 output_signature): 3633 """Constructs a `_GeneratorDataset`. 3634 3635 Args: 3636 init_args: A nested structure representing the arguments to `init_func`. 3637 init_func: A TensorFlow function that will be called on `init_args` each 3638 time a C++ iterator over this dataset is constructed. Returns a nested 3639 structure representing the "state" of the dataset. 3640 next_func: A TensorFlow function that will be called on the result of 3641 `init_func` to produce each element, and that raises `OutOfRangeError` 3642 to terminate iteration. 3643 finalize_func: A TensorFlow function that will be called on the result of 3644 `init_func` immediately before a C++ iterator over this dataset is 3645 destroyed. The return value is ignored. 3646 output_signature: A nested structure of `tf.TypeSpec` objects describing 3647 the output of `next_func`. 
3648 """ 3649 self._init_args = init_args 3650 3651 self._init_structure = structure.type_spec_from_value(init_args) 3652 3653 self._init_func = StructuredFunctionWrapper( 3654 init_func, 3655 self._transformation_name(), 3656 input_structure=self._init_structure) 3657 3658 self._next_func = StructuredFunctionWrapper( 3659 next_func, 3660 self._transformation_name(), 3661 input_structure=self._init_func.output_structure) 3662 3663 self._finalize_func = StructuredFunctionWrapper( 3664 finalize_func, 3665 self._transformation_name(), 3666 input_structure=self._init_func.output_structure) 3667 3668 self._output_signature = output_signature 3669 3670 variant_tensor = gen_dataset_ops.generator_dataset( 3671 structure.to_tensor_list(self._init_structure, self._init_args) + 3672 self._init_func.function.captured_inputs, 3673 self._next_func.function.captured_inputs, 3674 self._finalize_func.function.captured_inputs, 3675 init_func=self._init_func.function, 3676 next_func=self._next_func.function, 3677 finalize_func=self._finalize_func.function, 3678 **self._flat_structure) 3679 super(_GeneratorDataset, self).__init__(variant_tensor) 3680 3681 @property 3682 def element_spec(self): 3683 return self._output_signature 3684 3685 def _transformation_name(self): 3686 return "Dataset.from_generator()" 3687 3688 3689class ZipDataset(DatasetV2): 3690 """A `Dataset` that zips its inputs together.""" 3691 3692 def __init__(self, datasets): 3693 """See `Dataset.zip()` for details.""" 3694 for ds in nest.flatten(datasets): 3695 if not isinstance(ds, DatasetV2): 3696 if isinstance(ds, list): 3697 message = ("The argument to `Dataset.zip()` must be a nested " 3698 "structure of `Dataset` objects. Nested structures do not " 3699 "support Python lists; please use a tuple instead.") 3700 else: 3701 message = ("The argument to `Dataset.zip()` must be a nested " 3702 "structure of `Dataset` objects.") 3703 raise TypeError(message) 3704 self._datasets = datasets 3705 self._structure = nest.pack_sequence_as( 3706 self._datasets, 3707 [ds.element_spec for ds in nest.flatten(self._datasets)]) 3708 variant_tensor = gen_dataset_ops.zip_dataset( 3709 [ds._variant_tensor for ds in nest.flatten(self._datasets)], 3710 **self._flat_structure) 3711 super(ZipDataset, self).__init__(variant_tensor) 3712 3713 def _inputs(self): 3714 return nest.flatten(self._datasets) 3715 3716 @property 3717 def element_spec(self): 3718 return self._structure 3719 3720 3721class ConcatenateDataset(DatasetV2): 3722 """A `Dataset` that concatenates its input with given dataset.""" 3723 3724 def __init__(self, input_dataset, dataset_to_concatenate): 3725 """See `Dataset.concatenate()` for details.""" 3726 self._input_dataset = input_dataset 3727 self._dataset_to_concatenate = dataset_to_concatenate 3728 3729 output_types = get_legacy_output_types(input_dataset) 3730 if output_types != get_legacy_output_types(dataset_to_concatenate): 3731 raise TypeError( 3732 "Two datasets to concatenate have different types %s and %s" % 3733 (output_types, get_legacy_output_types(dataset_to_concatenate))) 3734 3735 output_classes = get_legacy_output_classes(input_dataset) 3736 if output_classes != get_legacy_output_classes(dataset_to_concatenate): 3737 raise TypeError( 3738 "Two datasets to concatenate have different classes %s and %s" % 3739 (output_classes, get_legacy_output_classes(dataset_to_concatenate))) 3740 3741 input_shapes = get_legacy_output_shapes(self._input_dataset) 3742 output_shapes = nest.pack_sequence_as(input_shapes, [ 3743 
ts1.most_specific_compatible_shape(ts2) 3744 for (ts1, ts2) in zip( 3745 nest.flatten(input_shapes), 3746 nest.flatten(get_legacy_output_shapes( 3747 self._dataset_to_concatenate))) 3748 ]) 3749 3750 self._structure = structure.convert_legacy_structure( 3751 output_types, output_shapes, output_classes) 3752 3753 self._input_datasets = [input_dataset, dataset_to_concatenate] 3754 # pylint: disable=protected-access 3755 variant_tensor = gen_dataset_ops.concatenate_dataset( 3756 input_dataset._variant_tensor, dataset_to_concatenate._variant_tensor, 3757 **self._flat_structure) 3758 # pylint: enable=protected-access 3759 super(ConcatenateDataset, self).__init__(variant_tensor) 3760 3761 def _inputs(self): 3762 return self._input_datasets 3763 3764 @property 3765 def element_spec(self): 3766 return self._structure 3767 3768 3769class RepeatDataset(UnaryUnchangedStructureDataset): 3770 """A `Dataset` that repeats its input several times.""" 3771 3772 def __init__(self, input_dataset, count): 3773 """See `Dataset.repeat()` for details.""" 3774 self._input_dataset = input_dataset 3775 if count is None: 3776 self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") 3777 else: 3778 self._count = ops.convert_to_tensor( 3779 count, dtype=dtypes.int64, name="count") 3780 variant_tensor = gen_dataset_ops.repeat_dataset( 3781 input_dataset._variant_tensor, # pylint: disable=protected-access 3782 count=self._count, 3783 **self._flat_structure) 3784 super(RepeatDataset, self).__init__(input_dataset, variant_tensor) 3785 3786 3787class RangeDataset(DatasetSource): 3788 """A `Dataset` of a step separated range of values.""" 3789 3790 def __init__(self, *args, **kwargs): 3791 """See `Dataset.range()` for details.""" 3792 self._parse_args(*args, **kwargs) 3793 self._structure = tensor_spec.TensorSpec([], self._output_type) 3794 variant_tensor = gen_dataset_ops.range_dataset( 3795 start=self._start, 3796 stop=self._stop, 3797 step=self._step, 3798 **self._flat_structure) 3799 super(RangeDataset, self).__init__(variant_tensor) 3800 3801 def _parse_args(self, *args, **kwargs): 3802 """Parse arguments according to the same rules as the `range()` builtin.""" 3803 if len(args) == 1: 3804 self._start = self._build_tensor(0, "start") 3805 self._stop = self._build_tensor(args[0], "stop") 3806 self._step = self._build_tensor(1, "step") 3807 elif len(args) == 2: 3808 self._start = self._build_tensor(args[0], "start") 3809 self._stop = self._build_tensor(args[1], "stop") 3810 self._step = self._build_tensor(1, "step") 3811 elif len(args) == 3: 3812 self._start = self._build_tensor(args[0], "start") 3813 self._stop = self._build_tensor(args[1], "stop") 3814 self._step = self._build_tensor(args[2], "step") 3815 else: 3816 raise ValueError("Invalid arguments to RangeDataset: %s" % str(args)) 3817 if "output_type" in kwargs: 3818 self._output_type = kwargs["output_type"] 3819 else: 3820 self._output_type = dtypes.int64 3821 3822 def _build_tensor(self, int64_value, name): 3823 return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name) 3824 3825 @property 3826 def element_spec(self): 3827 return self._structure 3828 3829 3830class CacheDataset(UnaryUnchangedStructureDataset): 3831 """A `Dataset` that caches elements of its input.""" 3832 3833 def __init__(self, input_dataset, filename): 3834 """See `Dataset.cache()` for details.""" 3835 self._input_dataset = input_dataset 3836 self._filename = ops.convert_to_tensor( 3837 filename, dtype=dtypes.string, name="filename") 3838 if tf2.enabled() and 
(context.executing_eagerly() or ops.inside_function()): 3839 variant_tensor = gen_dataset_ops.cache_dataset_v2( 3840 input_dataset._variant_tensor, # pylint: disable=protected-access 3841 filename=self._filename, 3842 cache=gen_dataset_ops.dummy_memory_cache(), 3843 **self._flat_structure) 3844 else: 3845 variant_tensor = gen_dataset_ops.cache_dataset( 3846 input_dataset._variant_tensor, # pylint: disable=protected-access 3847 filename=self._filename, 3848 **self._flat_structure) 3849 super(CacheDataset, self).__init__(input_dataset, variant_tensor) 3850 3851 3852class ShuffleDataset(UnaryUnchangedStructureDataset): 3853 """A `Dataset` that randomly shuffles the elements of its input.""" 3854 3855 def __init__(self, 3856 input_dataset, 3857 buffer_size, 3858 seed=None, 3859 reshuffle_each_iteration=None): 3860 """Randomly shuffles the elements of this dataset. 3861 3862 Args: 3863 input_dataset: The input dataset. 3864 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 3865 elements from this dataset from which the new dataset will sample. 3866 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random 3867 seed that will be used to create the distribution. See 3868 `tf.random.set_seed` for behavior. 3869 reshuffle_each_iteration: (Optional.) A boolean, which if true indicates 3870 that the dataset should be pseudorandomly reshuffled each time it is 3871 iterated over. (Defaults to `True`.) 3872 3873 Returns: 3874 A `Dataset`. 3875 3876 Raises: 3877 ValueError: if invalid arguments are provided. 3878 """ 3879 self._input_dataset = input_dataset 3880 self._buffer_size = ops.convert_to_tensor( 3881 buffer_size, dtype=dtypes.int64, name="buffer_size") 3882 self._seed, self._seed2 = random_seed.get_seed(seed) 3883 if reshuffle_each_iteration is None: 3884 reshuffle_each_iteration = True 3885 self._reshuffle_each_iteration = reshuffle_each_iteration 3886 3887 if (tf2.enabled() and 3888 (context.executing_eagerly() or ops.inside_function())): 3889 variant_tensor = gen_dataset_ops.shuffle_dataset_v3( 3890 input_dataset._variant_tensor, # pylint: disable=protected-access 3891 buffer_size=self._buffer_size, 3892 seed=self._seed, 3893 seed2=self._seed2, 3894 seed_generator=gen_dataset_ops.dummy_seed_generator(), 3895 reshuffle_each_iteration=self._reshuffle_each_iteration, 3896 **self._flat_structure) 3897 else: 3898 variant_tensor = gen_dataset_ops.shuffle_dataset( 3899 input_dataset._variant_tensor, # pylint: disable=protected-access 3900 buffer_size=self._buffer_size, 3901 seed=self._seed, 3902 seed2=self._seed2, 3903 reshuffle_each_iteration=self._reshuffle_each_iteration, 3904 **self._flat_structure) 3905 super(ShuffleDataset, self).__init__(input_dataset, variant_tensor) 3906 3907 3908class TakeDataset(UnaryUnchangedStructureDataset): 3909 """A `Dataset` containing the first `count` elements from its input.""" 3910 3911 def __init__(self, input_dataset, count): 3912 """See `Dataset.take()` for details.""" 3913 self._input_dataset = input_dataset 3914 self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") 3915 variant_tensor = gen_dataset_ops.take_dataset( 3916 input_dataset._variant_tensor, # pylint: disable=protected-access 3917 count=self._count, 3918 **self._flat_structure) 3919 super(TakeDataset, self).__init__(input_dataset, variant_tensor) 3920 3921 3922class SkipDataset(UnaryUnchangedStructureDataset): 3923 """A `Dataset` skipping the first `count` elements from its input.""" 3924 3925 def __init__(self, input_dataset, count): 
3926 """See `Dataset.skip()` for details.""" 3927 self._input_dataset = input_dataset 3928 self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") 3929 variant_tensor = gen_dataset_ops.skip_dataset( 3930 input_dataset._variant_tensor, # pylint: disable=protected-access 3931 count=self._count, 3932 **self._flat_structure) 3933 super(SkipDataset, self).__init__(input_dataset, variant_tensor) 3934 3935 3936class ShardDataset(UnaryUnchangedStructureDataset): 3937 """A `Dataset` for sharding its input.""" 3938 3939 def __init__(self, input_dataset, num_shards, index): 3940 """See `Dataset.shard()` for details.""" 3941 self._input_dataset = input_dataset 3942 self._num_shards = ops.convert_to_tensor( 3943 num_shards, dtype=dtypes.int64, name="num_shards") 3944 self._index = ops.convert_to_tensor(index, dtype=dtypes.int64, name="index") 3945 variant_tensor = gen_dataset_ops.shard_dataset( 3946 input_dataset._variant_tensor, # pylint: disable=protected-access 3947 num_shards=self._num_shards, 3948 index=self._index, 3949 **self._flat_structure) 3950 super(ShardDataset, self).__init__(input_dataset, variant_tensor) 3951 3952 3953class BatchDataset(UnaryDataset): 3954 """A `Dataset` that batches contiguous elements from its input.""" 3955 3956 def __init__(self, input_dataset, batch_size, drop_remainder): 3957 """See `Dataset.batch()` for details.""" 3958 self._input_dataset = input_dataset 3959 self._batch_size = ops.convert_to_tensor( 3960 batch_size, dtype=dtypes.int64, name="batch_size") 3961 self._drop_remainder = ops.convert_to_tensor( 3962 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 3963 3964 constant_drop_remainder = tensor_util.constant_value(self._drop_remainder) 3965 # pylint: disable=protected-access 3966 if constant_drop_remainder: 3967 # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically) 3968 # or `False` (explicitly retaining the remainder). 
3969 # pylint: disable=g-long-lambda 3970 constant_batch_size = tensor_util.constant_value(self._batch_size) 3971 self._structure = nest.map_structure( 3972 lambda component_spec: component_spec._batch(constant_batch_size), 3973 input_dataset.element_spec) 3974 else: 3975 self._structure = nest.map_structure( 3976 lambda component_spec: component_spec._batch(None), 3977 input_dataset.element_spec) 3978 variant_tensor = gen_dataset_ops.batch_dataset_v2( 3979 input_dataset._variant_tensor, 3980 batch_size=self._batch_size, 3981 drop_remainder=self._drop_remainder, 3982 **self._flat_structure) 3983 super(BatchDataset, self).__init__(input_dataset, variant_tensor) 3984 3985 @property 3986 def element_spec(self): 3987 return self._structure 3988 3989 3990class ParallelBatchDataset(UnaryDataset): 3991 """A `Dataset` that batches contiguous elements from its input in parallel.""" 3992 3993 def __init__(self, input_dataset, batch_size, drop_remainder, 3994 num_parallel_calls): 3995 """See `Dataset.batch()` for details.""" 3996 self._input_dataset = input_dataset 3997 self._batch_size = ops.convert_to_tensor( 3998 batch_size, dtype=dtypes.int64, name="batch_size") 3999 self._drop_remainder = ops.convert_to_tensor( 4000 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 4001 self._num_parallel_calls = ops.convert_to_tensor( 4002 num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") 4003 4004 constant_drop_remainder = tensor_util.constant_value(self._drop_remainder) 4005 # pylint: disable=protected-access 4006 if constant_drop_remainder: 4007 # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically) 4008 # or `False` (explicitly retaining the remainder). 4009 # pylint: disable=g-long-lambda 4010 constant_batch_size = tensor_util.constant_value(self._batch_size) 4011 self._structure = nest.map_structure( 4012 lambda component_spec: component_spec._batch(constant_batch_size), 4013 input_dataset.element_spec) 4014 else: 4015 self._structure = nest.map_structure( 4016 lambda component_spec: component_spec._batch(None), 4017 input_dataset.element_spec) 4018 variant_tensor = gen_dataset_ops.parallel_batch_dataset( 4019 input_dataset._variant_tensor, 4020 batch_size=self._batch_size, 4021 num_parallel_calls=self._num_parallel_calls, 4022 drop_remainder=self._drop_remainder, 4023 **self._flat_structure) 4024 super(ParallelBatchDataset, self).__init__(input_dataset, variant_tensor) 4025 4026 @property 4027 def element_spec(self): 4028 return self._structure 4029 4030 4031class _NumpyIterator(object): 4032 """Iterator over a dataset with elements converted to numpy.""" 4033 4034 __slots__ = ["_iterator"] 4035 4036 def __init__(self, dataset): 4037 self._iterator = iter(dataset) 4038 4039 def __iter__(self): 4040 return self 4041 4042 def __next__(self): 4043 4044 def to_numpy(x): 4045 numpy = x._numpy() # pylint: disable=protected-access 4046 if isinstance(numpy, np.ndarray): 4047 # `numpy` shares the same underlying buffer as the `x` Tensor. 4048 # Tensors are expected to be immutable, so we disable writes. 4049 numpy.setflags(write=False) 4050 return numpy 4051 4052 return nest.map_structure(to_numpy, next(self._iterator)) 4053 4054 def next(self): 4055 return self.__next__() 4056 4057 4058class _VariantTracker(tracking.CapturableResource): 4059 """Allows export of functions capturing a Dataset in SavedModels. 4060 4061 When saving a SavedModel, `tf.saved_model.save` traverses the object 4062 graph. 
Since Datasets reference _VariantTracker objects, that traversal will 4063 find a _VariantTracker for each Dataset and so know how to save and restore 4064 functions which reference the Dataset's variant Tensor. 4065 """ 4066 4067 def __init__(self, variant_tensor, resource_creator): 4068 """Record that `variant_tensor` is associated with `resource_creator`. 4069 4070 Args: 4071 variant_tensor: The variant-dtype Tensor associated with the Dataset. This 4072 Tensor will be a captured input to functions which use the Dataset, and 4073 is used by saving code to identify the corresponding _VariantTracker. 4074 resource_creator: A zero-argument function which creates a new 4075 variant-dtype Tensor. This function will be included in SavedModels and 4076 run to re-create the Dataset's variant Tensor on restore. 4077 """ 4078 super(_VariantTracker, self).__init__(device="CPU") 4079 self._resource_handle = variant_tensor 4080 self._create_resource = resource_creator 4081 4082 4083def _is_padded_shape_compatible_with(padded_shape, input_component_shape): 4084 """Returns `True` if `input_component_shape` can be padded to `padded_shape`. 4085 4086 Args: 4087 padded_shape: A `tf.TensorShape`. 4088 input_component_shape: A `tf.TensorShape`. 4089 4090 Returns: 4091 `True` if `input_component_shape` can be padded to `padded_shape`, otherwise 4092 `False`. 4093 """ 4094 4095 if padded_shape.dims is None or input_component_shape.dims is None: 4096 return True 4097 if len(padded_shape.dims) != len(input_component_shape.dims): 4098 return False 4099 for padded_dim, input_dim in zip( 4100 padded_shape.dims, input_component_shape.dims): 4101 if (padded_dim.value is not None and input_dim.value is not None 4102 and padded_dim.value < input_dim.value): 4103 return False 4104 return True 4105 4106 4107def _padded_shape_to_tensor(padded_shape, input_component_shape): 4108 """Converts `padded_shape` to a `tf.Tensor` representing that shape. 4109 4110 Args: 4111 padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python 4112 sequence, or a 1-D `tf.Tensor` of `tf.int64` elements. 4113 input_component_shape: A `tf.TensorShape`, with which `padded_shape` must 4114 be compatible. 4115 4116 Returns: 4117 A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`. 4118 4119 Raises: 4120 ValueError: If `padded_shape` is not a shape or not compatible with 4121 `input_component_shape`. 4122 TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor. 4123 """ 4124 try: 4125 # Try to convert the `padded_shape` to a `tf.TensorShape` 4126 padded_shape_as_shape = tensor_shape.as_shape(padded_shape) 4127 # We will return the "canonical" tensor representation, which uses 4128 # `-1` in place of `None`. 4129 ret = ops.convert_to_tensor( 4130 [dim if dim is not None else -1 4131 for dim in padded_shape_as_shape.as_list()], dtype=dtypes.int64) 4132 except (TypeError, ValueError): 4133 # The argument was not trivially convertible to a 4134 # `tf.TensorShape`, so fall back on the conversion to tensor 4135 # machinery. 4136 ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64) 4137 if ret.shape.dims is not None and len(ret.shape.dims) != 1: 4138 six.reraise(ValueError, ValueError( 4139 "Padded shape %s must be a 1-D tensor of tf.int64 values, but its " 4140 "shape was %s." 
% (padded_shape, ret.shape)), sys.exc_info()[2]) 4141 if ret.dtype != dtypes.int64: 4142 six.reraise( 4143 TypeError, 4144 TypeError( 4145 "Padded shape %s must be a 1-D tensor of tf.int64 values, but " 4146 "its element type was %s." % (padded_shape, ret.dtype.name)), 4147 sys.exc_info()[2]) 4148 padded_shape_as_shape = tensor_util.constant_value_as_shape(ret) 4149 4150 if not _is_padded_shape_compatible_with(padded_shape_as_shape, 4151 input_component_shape): 4152 raise ValueError("The padded shape %s is not compatible with the " 4153 "corresponding input component shape %s." 4154 % (padded_shape_as_shape, input_component_shape)) 4155 4156 return ret 4157 4158 4159def _padding_value_to_tensor(value, output_type): 4160 """Converts the padding value to a tensor. 4161 4162 Args: 4163 value: The padding value. 4164 output_type: Its expected dtype. 4165 4166 Returns: 4167 A scalar `Tensor`. 4168 4169 Raises: 4170 ValueError: if the padding value is not a scalar. 4171 TypeError: if the padding value's type does not match `output_type`. 4172 """ 4173 value = ops.convert_to_tensor(value, name="padding_value") 4174 if not value.shape.is_compatible_with(tensor_shape.TensorShape([])): 4175 raise ValueError("Padding value should be a scalar, but is not: %s" % value) 4176 if value.dtype != output_type: 4177 raise TypeError("Padding value tensor (%s) does not match output type: %s" % 4178 (value, output_type)) 4179 return value 4180 4181 4182def _padding_values_or_default(padding_values, input_dataset): 4183 """Returns padding values with None elements replaced with default values.""" 4184 4185 def make_zero(t): 4186 if t.base_dtype == dtypes.string: 4187 return "" 4188 elif t.base_dtype == dtypes.variant: 4189 error_msg = ("Unable to create padding for field of type 'variant' " 4190 "because t.base_type == dtypes.variant == " 4191 "{}.".format(t.base_dtype)) 4192 raise TypeError(error_msg) 4193 elif t.base_dtype == dtypes.bfloat16: 4194 # Special case `bfloat16` because it is not supported by NumPy. 
4195 return constant_op.constant(0, dtype=dtypes.bfloat16) 4196 else: 4197 return np.zeros_like(t.as_numpy_dtype()) 4198 4199 def value_or_default(value, default): 4200 return default if value is None else value 4201 4202 default_padding = nest.map_structure( 4203 make_zero, 4204 get_legacy_output_types(input_dataset)) 4205 return nest.map_structure_up_to(padding_values, value_or_default, 4206 padding_values, default_padding) 4207 4208 4209class PaddedBatchDataset(UnaryDataset): 4210 """A `Dataset` that batches and pads contiguous elements from its input.""" 4211 4212 def __init__(self, input_dataset, batch_size, padded_shapes, padding_values, 4213 drop_remainder): 4214 """See `Dataset.batch()` for details.""" 4215 self._input_dataset = input_dataset 4216 4217 def check_types(component_spec): 4218 if not isinstance(component_spec, tensor_spec.TensorSpec): 4219 raise TypeError("Padded batching of components of type ", 4220 type(component_spec), " is not supported.") 4221 4222 nest.map_structure(check_types, input_dataset.element_spec) 4223 self._input_dataset = input_dataset 4224 self._batch_size = ops.convert_to_tensor( 4225 batch_size, dtype=dtypes.int64, name="batch_size") 4226 padding_values = _padding_values_or_default(padding_values, input_dataset) 4227 4228 input_shapes = get_legacy_output_shapes(input_dataset) 4229 flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes) 4230 4231 flat_padded_shapes_as_tensors = [] 4232 4233 for input_component_shape, padded_shape in zip( 4234 nest.flatten(input_shapes), flat_padded_shapes): 4235 flat_padded_shapes_as_tensors.append( 4236 _padded_shape_to_tensor(padded_shape, input_component_shape)) 4237 4238 self._padded_shapes = nest.pack_sequence_as(input_shapes, 4239 flat_padded_shapes_as_tensors) 4240 4241 # If padding_values is a single element and input_shapes is a structure, 4242 # "broadcast" padding_values to the same structure as input_shapes. 4243 if nest.is_sequence(input_shapes) and not nest.is_sequence(padding_values): 4244 padding_values = nest.map_structure(lambda _: padding_values, 4245 input_shapes) 4246 4247 self._padding_values = nest.map_structure_up_to( 4248 input_shapes, _padding_value_to_tensor, padding_values, 4249 get_legacy_output_types(input_dataset)) 4250 self._drop_remainder = ops.convert_to_tensor( 4251 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 4252 4253 def _padded_shape_to_batch_shape(s): 4254 return tensor_shape.TensorShape([ 4255 tensor_util.constant_value(self._batch_size) 4256 if smart_cond.smart_constant_value(self._drop_remainder) else None 4257 ]).concatenate(tensor_util.constant_value_as_shape(s)) 4258 4259 output_shapes = nest.map_structure( 4260 _padded_shape_to_batch_shape, self._padded_shapes) 4261 self._structure = structure.convert_legacy_structure( 4262 get_legacy_output_types(self._input_dataset), output_shapes, 4263 get_legacy_output_classes(self._input_dataset)) 4264 4265 # pylint: disable=protected-access 4266 # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018. 
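    # The branch below keeps emitting the original `PaddedBatchDataset` op
    # when `drop_remainder` is statically False (presumably to preserve
    # compatibility with previously serialized graphs; see the TODO above),
    # and uses `PaddedBatchDatasetV2`, which takes an explicit
    # `drop_remainder` argument, in all other cases.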
4267 if smart_cond.smart_constant_value(self._drop_remainder) is False: 4268 variant_tensor = gen_dataset_ops.padded_batch_dataset( 4269 input_dataset._variant_tensor, # pylint: disable=protected-access 4270 batch_size=self._batch_size, 4271 padded_shapes=[ 4272 ops.convert_to_tensor(s, dtype=dtypes.int64) 4273 for s in nest.flatten(self._padded_shapes) 4274 ], 4275 padding_values=nest.flatten(self._padding_values), 4276 output_shapes=structure.get_flat_tensor_shapes(self._structure)) 4277 else: 4278 variant_tensor = gen_dataset_ops.padded_batch_dataset_v2( 4279 input_dataset._variant_tensor, # pylint: disable=protected-access 4280 batch_size=self._batch_size, 4281 padded_shapes=[ 4282 ops.convert_to_tensor(s, dtype=dtypes.int64) 4283 for s in nest.flatten(self._padded_shapes) 4284 ], 4285 padding_values=nest.flatten(self._padding_values), 4286 drop_remainder=self._drop_remainder, 4287 output_shapes=structure.get_flat_tensor_shapes(self._structure)) 4288 super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor) 4289 4290 @property 4291 def element_spec(self): 4292 return self._structure 4293 4294 4295def _should_unpack_args(args): 4296 """Returns `True` if `args` should be `*args` when passed to a callable.""" 4297 return type(args) is tuple # pylint: disable=unidiomatic-typecheck 4298 4299 4300class MapDataset(UnaryDataset): 4301 """A `Dataset` that maps a function over elements in its input.""" 4302 4303 def __init__(self, 4304 input_dataset, 4305 map_func, 4306 use_inter_op_parallelism=True, 4307 preserve_cardinality=False, 4308 use_legacy_function=False): 4309 """See `Dataset.map()` for details.""" 4310 self._input_dataset = input_dataset 4311 self._use_inter_op_parallelism = use_inter_op_parallelism 4312 self._preserve_cardinality = preserve_cardinality 4313 self._map_func = StructuredFunctionWrapper( 4314 map_func, 4315 self._transformation_name(), 4316 dataset=input_dataset, 4317 use_legacy_function=use_legacy_function) 4318 variant_tensor = gen_dataset_ops.map_dataset( 4319 input_dataset._variant_tensor, # pylint: disable=protected-access 4320 self._map_func.function.captured_inputs, 4321 f=self._map_func.function, 4322 use_inter_op_parallelism=self._use_inter_op_parallelism, 4323 preserve_cardinality=self._preserve_cardinality, 4324 **self._flat_structure) 4325 super(MapDataset, self).__init__(input_dataset, variant_tensor) 4326 4327 def _functions(self): 4328 return [self._map_func] 4329 4330 @property 4331 def element_spec(self): 4332 return self._map_func.output_structure 4333 4334 def _transformation_name(self): 4335 return "Dataset.map()" 4336 4337 4338class ParallelMapDataset(UnaryDataset): 4339 """A `Dataset` that maps a function over elements in its input in parallel.""" 4340 4341 def __init__(self, 4342 input_dataset, 4343 map_func, 4344 num_parallel_calls, 4345 deterministic, 4346 use_inter_op_parallelism=True, 4347 preserve_cardinality=False, 4348 use_legacy_function=False): 4349 """See `Dataset.map()` for details.""" 4350 self._input_dataset = input_dataset 4351 self._use_inter_op_parallelism = use_inter_op_parallelism 4352 self._map_func = StructuredFunctionWrapper( 4353 map_func, 4354 self._transformation_name(), 4355 dataset=input_dataset, 4356 use_legacy_function=use_legacy_function) 4357 if deterministic is None: 4358 self._deterministic = "default" 4359 elif deterministic: 4360 self._deterministic = "true" 4361 else: 4362 self._deterministic = "false" 4363 self._preserve_cardinality = preserve_cardinality 4364 self._num_parallel_calls = 


class FlatMapDataset(UnaryDataset):
  """A `Dataset` that maps a function over its input and flattens the result."""

  def __init__(self, input_dataset, map_func):
    """See `Dataset.flat_map()` for details."""
    self._input_dataset = input_dataset
    self._map_func = StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    if not isinstance(self._map_func.output_structure, DatasetSpec):
      raise TypeError(
          "`map_func` must return a `Dataset` object. Got {}".format(
              type(self._map_func.output_structure)))
    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
    variant_tensor = gen_dataset_ops.flat_map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        **self._flat_structure)
    super(FlatMapDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def element_spec(self):
    return self._structure

  def _transformation_name(self):
    return "Dataset.flat_map()"


class InterleaveDataset(UnaryDataset):
  """A `Dataset` that interleaves the result of transformed inputs."""

  def __init__(self, input_dataset, map_func, cycle_length, block_length):
    """See `Dataset.interleave()` for details."""

    self._input_dataset = input_dataset
    self._map_func = StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    if not isinstance(self._map_func.output_structure, DatasetSpec):
      raise TypeError(
          "`map_func` must return a `Dataset` object. Got {}".format(
              type(self._map_func.output_structure)))
    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")

    variant_tensor = gen_dataset_ops.interleave_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
        self._cycle_length,
        self._block_length,
        f=self._map_func.function,
        **self._flat_structure)
    super(InterleaveDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def element_spec(self):
    return self._structure

  def _transformation_name(self):
    return "Dataset.interleave()"
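

# A minimal usage sketch (assumes TF 2.x eager execution): both classes above
# require `map_func` to return a `tf.data.Dataset`. `flat_map()` concatenates
# the returned datasets in order, while `interleave()` cycles through
# `cycle_length` of them, taking `block_length` elements at a time.
#
#   >>> ds = tf.data.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]])
#   >>> ds = ds.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x))
#   >>> list(ds.as_numpy_iterator())
#   [1, 2, 3, 4, 5, 6]
#
#   >>> ds = tf.data.Dataset.range(1, 3)
#   >>> ds = ds.interleave(
#   ...     lambda x: tf.data.Dataset.from_tensors(x).repeat(4),
#   ...     cycle_length=2, block_length=2)
#   >>> list(ds.as_numpy_iterator())
#   [1, 1, 2, 2, 1, 1, 2, 2]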
Got {}".format( 4431 type(self._map_func.output_structure))) 4432 self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access 4433 self._cycle_length = ops.convert_to_tensor( 4434 cycle_length, dtype=dtypes.int64, name="cycle_length") 4435 self._block_length = ops.convert_to_tensor( 4436 block_length, dtype=dtypes.int64, name="block_length") 4437 4438 variant_tensor = gen_dataset_ops.interleave_dataset( 4439 input_dataset._variant_tensor, # pylint: disable=protected-access 4440 self._map_func.function.captured_inputs, # pylint: disable=protected-access 4441 self._cycle_length, 4442 self._block_length, 4443 f=self._map_func.function, 4444 **self._flat_structure) 4445 super(InterleaveDataset, self).__init__(input_dataset, variant_tensor) 4446 4447 def _functions(self): 4448 return [self._map_func] 4449 4450 @property 4451 def element_spec(self): 4452 return self._structure 4453 4454 def _transformation_name(self): 4455 return "Dataset.interleave()" 4456 4457 4458class ParallelInterleaveDataset(UnaryDataset): 4459 """A `Dataset` that maps a function over its input and interleaves the result.""" 4460 4461 def __init__(self, 4462 input_dataset, 4463 map_func, 4464 cycle_length, 4465 block_length, 4466 num_parallel_calls, 4467 buffer_output_elements=AUTOTUNE, 4468 prefetch_input_elements=AUTOTUNE, 4469 deterministic=None): 4470 """See `Dataset.interleave()` for details.""" 4471 self._input_dataset = input_dataset 4472 self._map_func = StructuredFunctionWrapper( 4473 map_func, self._transformation_name(), dataset=input_dataset) 4474 if not isinstance(self._map_func.output_structure, DatasetSpec): 4475 raise TypeError( 4476 "`map_func` must return a `Dataset` object. Got {}".format( 4477 type(self._map_func.output_structure))) 4478 self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access 4479 self._cycle_length = ops.convert_to_tensor( 4480 cycle_length, dtype=dtypes.int64, name="cycle_length") 4481 self._block_length = ops.convert_to_tensor( 4482 block_length, dtype=dtypes.int64, name="block_length") 4483 self._buffer_output_elements = ops.convert_to_tensor( 4484 buffer_output_elements, 4485 dtype=dtypes.int64, 4486 name="buffer_output_elements") 4487 self._prefetch_input_elements = ops.convert_to_tensor( 4488 prefetch_input_elements, 4489 dtype=dtypes.int64, 4490 name="prefetch_input_elements") 4491 4492 self._num_parallel_calls = ops.convert_to_tensor( 4493 num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") 4494 if deterministic is None: 4495 deterministic_string = "default" 4496 elif deterministic: 4497 deterministic_string = "true" 4498 else: 4499 deterministic_string = "false" 4500 4501 variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4( 4502 input_dataset._variant_tensor, # pylint: disable=protected-access 4503 self._map_func.function.captured_inputs, # pylint: disable=protected-access 4504 self._cycle_length, 4505 self._block_length, 4506 self._buffer_output_elements, 4507 self._prefetch_input_elements, 4508 self._num_parallel_calls, 4509 f=self._map_func.function, 4510 deterministic=deterministic_string, 4511 **self._flat_structure) 4512 super(ParallelInterleaveDataset, self).__init__(input_dataset, 4513 variant_tensor) 4514 4515 def _functions(self): 4516 return [self._map_func] 4517 4518 @property 4519 def element_spec(self): 4520 return self._structure 4521 4522 def _transformation_name(self): 4523 return "Dataset.interleave()" 4524 4525 4526class 


class PrefetchDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that asynchronously prefetches its input."""

  def __init__(self, input_dataset, buffer_size, slack_period=None):
    """See `Dataset.prefetch()` for details.

    Args:
      input_dataset: The input dataset.
      buffer_size: See `Dataset.prefetch()` for details.
      slack_period: (Optional.) An integer. If non-zero, determines the number
        of GetNext calls before injecting slack into the execution. This may
        reduce CPU contention at the start of a step. Note that a TensorFlow
        user should not need to set this manually; it is enabled automatically
        via `tf.data.Options.experimental_slack` instead. Defaults to None.
    """
    self._input_dataset = input_dataset
    if buffer_size is None:
      buffer_size = AUTOTUNE
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")
    # pylint: disable=protected-access
    # We colocate the prefetch dataset with its input because this colocation
    # only happens automatically in graph mode.
    with ops.colocate_with(input_dataset._variant_tensor):
      variant_tensor = gen_dataset_ops.prefetch_dataset(
          input_dataset._variant_tensor,
          buffer_size=self._buffer_size,
          slack_period=slack_period,
          **self._flat_structure)
    super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)
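

# A minimal usage sketch (assumes TF 2.x eager execution): `PrefetchDataset`
# is created via `tf.data.Dataset.prefetch()`. A buffer size of
# `tf.data.AUTOTUNE` (or `None`, which is mapped to `AUTOTUNE` above) lets
# tf.data tune the buffer dynamically.
#
#   >>> ds = tf.data.Dataset.range(4).prefetch(buffer_size=tf.data.AUTOTUNE)
#   >>> list(ds.as_numpy_iterator())
#   [0, 1, 2, 3]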


class WindowDataset(UnaryDataset):
  """A dataset that creates window datasets from the input elements."""

  def __init__(self, input_dataset, size, shift, stride, drop_remainder):
    """See `Dataset.window()` for more details."""
    self._input_dataset = input_dataset
    self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
    self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
    self._stride = ops.convert_to_tensor(
        stride, dtype=dtypes.int64, name="stride")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
    self._structure = nest.pack_sequence_as(
        get_legacy_output_classes(input_dataset), [
            DatasetSpec(  # pylint: disable=g-complex-comprehension
                structure.convert_legacy_structure(
                    output_type, output_shape, output_class))
            for output_class, output_shape, output_type in zip(
                nest.flatten(get_legacy_output_classes(input_dataset)),
                nest.flatten(get_legacy_output_shapes(input_dataset)),
                nest.flatten(get_legacy_output_types(input_dataset)))
        ])
    variant_tensor = gen_dataset_ops.window_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._size,
        self._shift,
        self._stride,
        self._drop_remainder,
        **self._flat_structure)
    super(WindowDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._structure


class _OptionsDataset(UnaryUnchangedStructureDataset):
  """An identity `Dataset` that stores options."""

  def __init__(self, input_dataset, options):
    self._input_dataset = input_dataset
    variant_tensor = input_dataset._variant_tensor  # pylint: disable=protected-access
    super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)

    if self._options_attr:
      self._options_attr = self._options_attr.merge(options)
    else:
      self._options_attr = options

  def options(self):
    return self._options_attr


class _ModelDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, and models performance."""

  def __init__(self, input_dataset, algorithm, cpu_budget, ram_budget):
    self._input_dataset = input_dataset
    variant_tensor = gen_dataset_ops.model_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        algorithm=algorithm.value,
        cpu_budget=cpu_budget,
        ram_budget=ram_budget,
        **self._flat_structure)
    super(_ModelDataset, self).__init__(input_dataset, variant_tensor)
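

# A minimal usage sketch (assumes TF 2.x eager execution): `WindowDataset` is
# created via `tf.data.Dataset.window()` and yields datasets of windows rather
# than tensors, while `_OptionsDataset` is created via
# `tf.data.Dataset.with_options()` and merges the given `tf.data.Options`
# with any options already set upstream.
#
#   >>> ds = tf.data.Dataset.range(7).window(3, shift=1, drop_remainder=True)
#   >>> for window in ds:
#   ...   print(list(window.as_numpy_iterator()))
#   [0, 1, 2]
#   [1, 2, 3]
#   [2, 3, 4]
#   [3, 4, 5]
#   [4, 5, 6]
#
#   >>> options = tf.data.Options()
#   >>> options.experimental_slack = True
#   >>> ds = ds.with_options(options)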


class _OptimizeDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, and applies optimizations."""

  def __init__(self,
               input_dataset,
               optimizations_enabled,
               optimizations_disabled,
               optimizations_default,
               optimization_configs=None):
    self._input_dataset = input_dataset
    if optimization_configs is None:
      optimization_configs = []

    # We sort the options here before embedding as constant tensors to ensure
    # that serialization to NodeDef is deterministic.
    if optimizations_enabled:
      optimizations_enabled.sort()
    if optimizations_disabled:
      optimizations_disabled.sort()
    if optimizations_default:
      optimizations_default.sort()

    self._optimizations_enabled = convert.optional_param_to_tensor(
        argument_name="optimizations_enabled",
        argument_value=optimizations_enabled,
        argument_default=[],
        argument_dtype=dtypes.string)
    self._optimizations_disabled = convert.optional_param_to_tensor(
        argument_name="optimizations_disabled",
        argument_value=optimizations_disabled,
        argument_default=[],
        argument_dtype=dtypes.string)
    self._optimizations_default = convert.optional_param_to_tensor(
        argument_name="optimizations_default",
        argument_value=optimizations_default,
        argument_default=[],
        argument_dtype=dtypes.string)

    variant_tensor = gen_dataset_ops.optimize_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._optimizations_enabled,
        self._optimizations_disabled,
        self._optimizations_default,
        optimization_configs=optimization_configs,
        **self._flat_structure)

    super(_OptimizeDataset, self).__init__(input_dataset, variant_tensor)


class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, and sets a stats aggregator."""

  def __init__(self, input_dataset, aggregator, prefix, counter_prefix):
    self._input_dataset = input_dataset
    self._stats_aggregator = aggregator
    self._prefix = prefix
    self._counter_prefix = counter_prefix
    variant_tensor = ged_ops.set_stats_aggregator_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._stats_aggregator._resource,  # pylint: disable=protected-access
        self._prefix,
        self._counter_prefix,
        **self._flat_structure)
    super(_SetStatsAggregatorDataset, self).__init__(input_dataset,
                                                     variant_tensor)


class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, overriding intra-op parallelism."""

  def __init__(self, input_dataset, max_intra_op_parallelism):
    self._input_dataset = input_dataset
    self._max_intra_op_parallelism = ops.convert_to_tensor(
        max_intra_op_parallelism,
        dtype=dtypes.int64,
        name="max_intra_op_parallelism")
    variant_tensor = ged_ops.max_intra_op_parallelism_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._max_intra_op_parallelism,
        **self._flat_structure)
    super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset,
                                                        variant_tensor)


class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, setting a private threadpool."""

  def __init__(self, input_dataset, num_threads):
    self._input_dataset = input_dataset
    self._num_threads = ops.convert_to_tensor(
        num_threads, dtype=dtypes.int64, name="num_threads")
    variant_tensor = ged_ops.private_thread_pool_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._num_threads,
        **self._flat_structure)
    super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
                                                    variant_tensor)
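

# A minimal usage sketch (assumes TF 2.x and the `experimental_threading`
# namespace of `tf.data.Options` in this TensorFlow version): the identity
# datasets above are applied by tf.data itself when options are resolved;
# user code reaches them through `tf.data.Options` rather than by
# constructing them directly.
#
#   >>> options = tf.data.Options()
#   >>> options.experimental_threading.private_threadpool_size = 8
#   >>> options.experimental_threading.max_intra_op_parallelism = 1
#   >>> ds = tf.data.Dataset.range(4).with_options(options)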


def normalize_to_dense(dataset):
  """Normalizes non-tensor components in a dataset to dense representations.

  This is necessary for dataset transformations that slice along the batch
  dimension and are oblivious to non-tensors, e.g. `unbatch`, `rebatch`.

  Args:
    dataset: Dataset to normalize.

  Returns:
    A dataset whose sparse and ragged tensors have been normalized to their
    dense representations.
  """

  # NOTE(mrry): This leads to a somewhat inefficient re-encoding step for all
  # non-tensor components.
  #
  # TODO(mrry): Consider optimizing this if it turns out to be a bottleneck.
  if _should_unpack_args(dataset.element_spec):
    def normalize(*args):
      return structure.to_batched_tensor_list(dataset.element_spec, tuple(args))
  else:
    def normalize(arg):
      return structure.to_batched_tensor_list(dataset.element_spec, arg)

  normalized_dataset = dataset.map(normalize)

  # NOTE(mrry): Our `map()` has lost information about the structure of
  # non-tensor components, so re-apply the structure of the original dataset.
  return _RestructuredDataset(normalized_dataset, dataset.element_spec)


class _RestructuredDataset(UnaryDataset):
  """An internal helper for changing the structure and shape of a dataset."""

  def __init__(self, dataset, structure):
    self._input_dataset = dataset
    self._structure = structure

    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._structure


class _UnbatchDataset(UnaryDataset):
  """A dataset that splits the elements of its input into multiple elements."""

  def __init__(self, input_dataset):
    """See `unbatch()` for more details."""
    flat_shapes = input_dataset._flat_shapes  # pylint: disable=protected-access
    if any(s.ndims == 0 for s in flat_shapes):
      raise ValueError("Cannot unbatch an input with scalar components.")
    known_batch_dim = tensor_shape.Dimension(None)
    for s in flat_shapes:
      try:
        known_batch_dim = known_batch_dim.merge_with(s[0])
      except ValueError:
        raise ValueError("Cannot unbatch an input whose components have "
                         "different batch sizes.")
    self._input_dataset = input_dataset
    self._structure = nest.map_structure(
        lambda component_spec: component_spec._unbatch(),  # pylint: disable=protected-access
        get_structure(input_dataset))
    variant_tensor = ged_ops.unbatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        **self._flat_structure)
    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._structure
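

# A minimal usage sketch (assumes TF 2.x eager execution): `_UnbatchDataset`
# is reached through `tf.data.Dataset.unbatch()`, which splits each element
# along its first (batch) dimension; scalar components are rejected, as the
# check above shows.
#
#   >>> ds = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
#   >>> list(ds.unbatch().as_numpy_iterator())
#   [1, 2, 3, 4]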
4847 """ 4848 4849 reads = [] 4850 writes = [] 4851 op = op_queue.pop() 4852 if op in seen_ops: 4853 return reads, writes 4854 seen_ops.add(op) 4855 # TODO(b/150139257): All resource inputs are in writes right now since we 4856 # have not updated the functional ops to set the special attribute that ACD 4857 # uses to figure out which of the op's inputs are read-only. 4858 reads, writes = acd_utils.get_read_write_resource_inputs(op) 4859 # Conservatively assume that any variant inputs are datasets. 4860 op_queue.extend(t.op for t in op.inputs if t.dtype == dtypes.variant) 4861 return reads, writes 4862 4863 op_queue = [op] 4864 seen_ops = set() 4865 all_reads = [] 4866 all_writes = [] 4867 while op_queue: 4868 reads, writes = _process(op_queue, seen_ops) 4869 all_reads.extend(reads) 4870 all_writes.extend(writes) 4871 4872 return all_reads, all_writes 4873 4874 4875@auto_control_deps.register_acd_resource_resolver 4876def _resource_resolver(op, resource_reads, resource_writes): 4877 """Updates resource inputs for tf.data ops with indirect dependencies.""" 4878 4879 updated = False 4880 if op.type in [ 4881 "DatasetToSingleElement", "DatasetToTFRecord", "ReduceDataset" 4882 ]: 4883 reads, writes = _collect_resource_inputs(op) 4884 for inp in reads: 4885 if inp not in resource_reads: 4886 updated = True 4887 resource_reads.add(inp) 4888 for inp in writes: 4889 if inp not in resource_writes: 4890 updated = True 4891 resource_writes.add(inp) 4892 4893 if op.type in [ 4894 "IteratorGetNext", "IteratorGetNextSync", "IteratorGetNextAsOptional" 4895 ]: 4896 iterator_resource = op.inputs[0] 4897 make_iterator_ops = [ 4898 op for op in iterator_resource.consumers() if op.type == "MakeIterator" 4899 ] 4900 4901 if len(make_iterator_ops) == 1: 4902 reads, writes = _collect_resource_inputs(make_iterator_ops[0]) 4903 for inp in reads: 4904 if inp not in resource_reads: 4905 updated = True 4906 resource_reads.add(inp) 4907 for inp in writes: 4908 if inp not in resource_writes: 4909 updated = True 4910 resource_writes.add(inp) 4911 4912 return updated 4913