# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import functools
import multiprocessing
import sys
import threading
import warnings
import weakref

import numpy as np
import six
from six.moves import queue as Queue  # pylint: disable=redefined-builtin

from tensorflow.core.framework import dataset_options_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python import tf2
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import random_seed
from tensorflow.python.data.util import structure
from tensorflow.python.data.util import traverse
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import auto_control_deps_utils as acd_utils
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed as core_random_seed
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.training.tracking import base as tracking_base
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest as tf_nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export

# Loaded lazily due to a circular dependency (roughly
# tf.function->wrap_function->dataset->autograph->tf.function).
# TODO(b/133251390): Use a regular import.
wrap_function = lazy_loader.LazyLoader(
    "wrap_function", globals(),
    "tensorflow.python.eager.wrap_function")
# TODO(mdan): Create a public API for this.
autograph_ctx = lazy_loader.LazyLoader(
    "autograph_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")
autograph = lazy_loader.LazyLoader(
    "autograph", globals(),
    "tensorflow.python.autograph.impl.api")
# Loaded lazily due to a circular dependency
# dataset_ops->interleave_ops->dataset_ops
# TODO(aaudibert): Switch to the core sample_from_datasets after it is migrated
# out of experimental. Then we can remove this lazy loading.
interleave_ops = lazy_loader.LazyLoader(
    "interleave_ops", globals(),
    "tensorflow.python.data.experimental.ops.interleave_ops"
)

ops.NotDifferentiable("ReduceDataset")

# A constant that can be used to enable auto-tuning.
AUTOTUNE = -1
tf_export("data.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
# TODO(b/168128531): Deprecate and remove this symbol.
tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")

# Constants representing infinite and unknown cardinalities.
INFINITE = -1
UNKNOWN = -2
tf_export("data.INFINITE_CARDINALITY").export_constant(__name__, "INFINITE")
tf_export("data.UNKNOWN_CARDINALITY").export_constant(__name__, "UNKNOWN")


@tf_export("data.Dataset", v1=[])
@six.add_metaclass(abc.ABCMeta)
class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
                composite_tensor.CompositeTensor):
  """Represents a potentially large set of elements.

  The `tf.data.Dataset` API supports writing descriptive and efficient input
  pipelines. `Dataset` usage follows a common pattern:

  1. Create a source dataset from your input data.
  2. Apply dataset transformations to preprocess the data.
  3. Iterate over the dataset and process the elements.

  Iteration happens in a streaming fashion, so the full dataset does not need
  to fit into memory.

  Source Datasets:

  The simplest way to create a dataset is to create it from a python `list`:

  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> for element in dataset:
  ...   print(element)
  tf.Tensor(1, shape=(), dtype=int32)
  tf.Tensor(2, shape=(), dtype=int32)
  tf.Tensor(3, shape=(), dtype=int32)

  To process lines from files, use `tf.data.TextLineDataset`:

  >>> dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])

  To process records written in the `TFRecord` format, use `TFRecordDataset`:

  >>> dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])

  To create a dataset of all files matching a pattern, use
  `tf.data.Dataset.list_files`:

  ```python
  dataset = tf.data.Dataset.list_files("/path/*.txt")
  ```

  See `tf.data.FixedLengthRecordDataset` and `tf.data.Dataset.from_generator`
  for more ways to create datasets.
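
  For example, `tf.data.Dataset.from_generator` can wrap a Python generator
  (a minimal sketch; `gen` below is a hypothetical generator function):

  ```python
  def gen():
    # Yields scalar int32 elements one at a time.
    for i in range(3):
      yield i

  dataset = tf.data.Dataset.from_generator(
      gen, output_signature=tf.TensorSpec(shape=(), dtype=tf.int32))
  ```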

  Transformations:

  Once you have a dataset, you can apply transformations to prepare the data
  for your model:

  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> dataset = dataset.map(lambda x: x*2)
  >>> list(dataset.as_numpy_iterator())
  [2, 4, 6]

  Common Terms:

  **Element**: A single output from calling `next()` on a dataset iterator.
  Elements may be nested structures containing multiple components. For
  example, the element `(1, (3, "apple"))` has one tuple nested in another
  tuple. The components are `1`, `3`, and `"apple"`.

  **Component**: The leaf in the nested structure of an element.

  Supported types:

  Elements can be nested structures of tuples, named tuples, and dictionaries.
  Note that Python lists are *not* treated as nested structures of components.
  Instead, lists are converted to tensors and treated as components. For
  example, the element `(1, [1, 2, 3])` has only two components; the tensor `1`
  and the tensor `[1, 2, 3]`. Element components can be of any type
  representable by `tf.TypeSpec`, including `tf.Tensor`, `tf.data.Dataset`,
  `tf.sparse.SparseTensor`, `tf.RaggedTensor`, and `tf.TensorArray`.

  ```python
  a = 1  # Integer element
  b = 2.0  # Float element
  c = (1, 2)  # Tuple element with 2 components
  d = {"a": (2, 2), "b": 3}  # Dict element with 3 components
  Point = collections.namedtuple("Point", ["x", "y"])
  e = Point(1, 2)  # Named tuple
  f = tf.data.Dataset.range(10)  # Dataset element
  ```

  For more information,
  read [this guide](https://www.tensorflow.org/guide/data).
  """

  def __init__(self, variant_tensor):
    """Creates a DatasetV2 object.

    This differs from DatasetV1: the DatasetV1 constructor takes no arguments,
    whereas DatasetV2 subclasses are expected to create a variant_tensor and
    pass it to the super() call.

    Args:
      variant_tensor: A DT_VARIANT tensor that represents the dataset.
    """
    self._variant_tensor_attr = variant_tensor
    weak_self = weakref.proxy(self)
    self._variant_tracker = self._track_trackable(
        _VariantTracker(
            self._variant_tensor,
            # _trace_variant_creation only works when executing eagerly, so we
            # don't want to run it immediately. We also want the
            # _VariantTracker to have a weak reference to the Dataset to avoid
            # creating reference cycles and making work for the garbage
            # collector.
            lambda: weak_self._trace_variant_creation()()),  # pylint: disable=unnecessary-lambda,protected-access
        name="_variant_tracker")
    self._graph_attr = ops.get_default_graph()

    # Initialize the options for this dataset and its inputs.
    self._options_attr = options_lib.Options()
    for input_dataset in self._inputs():
      input_options = None
      if isinstance(input_dataset, DatasetV1):
        # If the V1 dataset does not have the `_dataset` attribute, we assume
        # it is a dataset source and hence does not have options. Otherwise,
        # we grab the options of the `_dataset` object.
        if hasattr(input_dataset, "_dataset"):
          if not isinstance(input_dataset._dataset, DatasetV2):
            raise AssertionError(
                "The input_dataset._dataset of dataset %s should be DatasetV2."
                % type(self))
          input_options = input_dataset._dataset._options_attr
      elif isinstance(input_dataset, DatasetV2):
        input_options = input_dataset._options_attr
      else:
        raise TypeError("Unexpected dataset type: %s" % type(input_dataset))
      if input_options is not None:
        self._options_attr = self._options_attr.merge(input_options)
    self._options_attr._set_mutable(False)  # pylint: disable=protected-access

  @property
  def _variant_tensor(self):
    return self._variant_tensor_attr

  @_variant_tensor.setter
  def _variant_tensor(self, _):
    raise ValueError("The _variant_tensor property is read-only")

  @deprecation.deprecated_args(None, "Use external_state_policy instead",
                               "allow_stateful")
  def _as_serialized_graph(
      self,
      allow_stateful=None,
      strip_device_assignment=None,
      external_state_policy=options_lib.ExternalStatePolicy.WARN):
    """Produces serialized graph representation of the dataset.

    Args:
      allow_stateful: If true, we allow stateful ops to be present in the
        graph def. In that case, the state in these ops would be thrown away.
      strip_device_assignment: If true, non-local (i.e. job and task) device
        assignment is stripped from ops in the serialized graph.
      external_state_policy: The ExternalStatePolicy enum that determines how
        we handle input pipelines that depend on external state. By default,
        it's set to WARN.

    Returns:
      A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
      serialized graph.
    """
    if external_state_policy:
      policy = external_state_policy.value
      return gen_dataset_ops.dataset_to_graph_v2(
          self._variant_tensor,
          external_state_policy=policy,
          strip_device_assignment=strip_device_assignment)
    if strip_device_assignment:
      return gen_dataset_ops.dataset_to_graph(
          self._variant_tensor,
          allow_stateful=allow_stateful,
          strip_device_assignment=strip_device_assignment)
    return gen_dataset_ops.dataset_to_graph(
        self._variant_tensor, allow_stateful=allow_stateful)

  def _trace_variant_creation(self):
    """Traces a function which outputs a variant `tf.Tensor` for this dataset.

    Note that creating this function involves evaluating an op, and is
    currently only supported when executing eagerly.

    Returns:
      A zero-argument `ConcreteFunction` which outputs a variant `tf.Tensor`.
    """
    variant = self._variant_tensor
    if not isinstance(variant, ops.EagerTensor):
      raise NotImplementedError(
          "Can only export Datasets which were created executing eagerly. "
          "Please file a feature request if this is important to you.")
    with context.eager_mode(), ops.device("CPU"):
      # pylint: disable=protected-access
      graph_def = graph_pb2.GraphDef().FromString(
          self._as_serialized_graph(external_state_policy=options_lib
                                    .ExternalStatePolicy.FAIL).numpy())
    output_node_name = None
    for node in graph_def.node:
      if node.op == "_Retval":
        if output_node_name is not None:
          raise AssertionError(
              "Found multiple return values from the dataset's graph, "
              "expected only one.")
        output_node_name, = node.input
    if output_node_name is None:
      raise AssertionError("Could not find the dataset's output node.")
    # Add functions used in this Dataset to the function's graph, since they
    # need to follow it around (and for example be added to a SavedModel which
    # references the dataset).
    variant_function = wrap_function.function_from_graph_def(
        graph_def, inputs=[], outputs=output_node_name + ":0")
    for used_function in self._functions():
      used_function.function.add_to_graph(variant_function.graph)
    return variant_function

  @abc.abstractmethod
  def _inputs(self):
    """Returns a list of the input datasets of the dataset."""

    raise NotImplementedError("Dataset._inputs")

  @property
  def _graph(self):
    return self._graph_attr

  @_graph.setter
  def _graph(self, _):
    raise ValueError("The _graph property is read-only")

  # TODO(jsimsa): Change this to be the transitive closure of functions used
  # by this dataset and its inputs.
  def _functions(self):
    """Returns a list of functions associated with this dataset.

    Returns:
      A list of `StructuredFunctionWrapper` objects.
    """
    return []

  def _options(self):
    """Returns the options tensor for this dataset."""
    # pylint: disable=protected-access
    return gen_dataset_ops.get_options(self._variant_tensor)

  @classmethod
  def _options_tensor_to_options(cls, serialized_options):
    """Converts options tensor to tf.data.Options object."""
    options = options_lib.Options()
    if tensor_util.constant_value(serialized_options) is not None:
      pb = dataset_options_pb2.Options.FromString(tensor_util.constant_value(
          serialized_options))
      options._from_proto(pb)  # pylint: disable=protected-access
    return options

  def options(self):
    """Returns the options for this dataset and its inputs.

    Returns:
      A `tf.data.Options` object representing the dataset options.
    """
    if context.executing_eagerly():
      options = self._options_tensor_to_options(self._options())
      options._set_mutable(False)  # pylint: disable=protected-access
      return options
    warnings.warn("To make it possible to preserve tf.data options across "
                  "serialization boundaries, their implementation has moved "
                  "to be part of the TensorFlow graph. As a consequence, the "
                  "options value is in general no longer known at graph "
                  "construction time. Invoking this method in graph mode "
                  "retains the legacy behavior of the original "
                  "implementation, but note that the returned value might "
                  "not reflect the actual value of the options.")
    return self._options_attr

  def _apply_debug_options(self):
    if DEBUG_MODE:
      # Disable autotuning and static optimizations that could introduce
      # parallelism or asynchrony.
      options = options_lib.Options()
      options.autotune.enabled = False
      options.experimental_optimization.map_and_batch_fusion = False
      options.experimental_optimization.map_parallelization = False
      dataset = _OptionsDataset(self, options)
    else:
      dataset = self

    return dataset

  def __iter__(self):
    """Creates an iterator for elements of this dataset.

    The returned iterator implements the Python Iterator protocol.
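
    For example, when executing eagerly (a minimal sketch):

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> iterator = iter(dataset)
    >>> print(next(iterator))
    tf.Tensor(1, shape=(), dtype=int32)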

    Returns:
      A `tf.data.Iterator` for the elements of this dataset.

    Raises:
      RuntimeError: If not inside of tf.function and not executing eagerly.
    """
    if context.executing_eagerly() or ops.inside_function():
      with ops.colocate_with(self._variant_tensor):
        return iterator_ops.OwnedIterator(self)
    else:
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")

  def __bool__(self):
    return True  # Required as __len__ is defined

  __nonzero__ = __bool__  # Python 2 backward compatibility

  def __len__(self):
    """Returns the length of the dataset if it is known and finite.

    This method requires that you are running in eager mode, and that the
    length of the dataset is known and non-infinite. When the length may be
    unknown or infinite, or if you are running in graph mode, use
    `tf.data.Dataset.cardinality` instead.

    Returns:
      An integer representing the length of the dataset.

    Raises:
      TypeError: If the dataset length is unknown or infinite, or if eager
        execution is not enabled.
    """
    if not context.executing_eagerly():
      raise TypeError("__len__() is not supported while tracing functions. "
                      "Use `tf.data.Dataset.cardinality` instead.")
    length = self.cardinality()
    if length.numpy() == INFINITE:
      raise TypeError("dataset length is infinite.")
    if length.numpy() == UNKNOWN:
      raise TypeError("dataset length is unknown.")
    return length

  @abc.abstractproperty
  def element_spec(self):
    """The type specification of an element of this dataset.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> dataset.element_spec
    TensorSpec(shape=(), dtype=tf.int32, name=None)

    For more information,
    read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).

    Returns:
      A (nested) structure of `tf.TypeSpec` objects matching the structure of
      an element of this dataset and specifying the type of individual
      components.
    """
    raise NotImplementedError("Dataset.element_spec")

  def __repr__(self):
    output_shapes = nest.map_structure(str, get_legacy_output_shapes(self))
    output_shapes = str(output_shapes).replace("'", "")
    output_types = nest.map_structure(repr, get_legacy_output_types(self))
    output_types = str(output_types).replace("'", "")
    return ("<%s shapes: %s, types: %s>" % (type(self).__name__, output_shapes,
                                            output_types))

  def as_numpy_iterator(self):
    """Returns an iterator which converts all elements of the dataset to numpy.

    Use `as_numpy_iterator` to inspect the content of your dataset. To see
    element shapes and types, print dataset elements directly instead of using
    `as_numpy_iterator`.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> for element in dataset:
    ...   print(element)
    tf.Tensor(1, shape=(), dtype=int32)
    tf.Tensor(2, shape=(), dtype=int32)
    tf.Tensor(3, shape=(), dtype=int32)

    This method requires that you are running in eager mode and the dataset's
    element_spec contains only `TensorSpec` components.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> for element in dataset.as_numpy_iterator():
    ...   print(element)
    1
    2
    3

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> print(list(dataset.as_numpy_iterator()))
    [1, 2, 3]

    `as_numpy_iterator()` will preserve the nested structure of dataset
    elements.

    >>> dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]),
    ...                                               'b': [5, 6]})
    >>> list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
    ...                                       {'a': (2, 4), 'b': 6}]
    True

    Returns:
      An iterable over the elements of the dataset, with their tensors
      converted to numpy arrays.

    Raises:
      TypeError: if an element contains a non-`Tensor` value.
      RuntimeError: if eager execution is not enabled.
    """
    if not context.executing_eagerly():
      raise RuntimeError("as_numpy_iterator() is not supported while tracing "
                         "functions")
    for component_spec in nest.flatten(self.element_spec):
      if not isinstance(
          component_spec,
          (tensor_spec.TensorSpec, ragged_tensor.RaggedTensorSpec)):
        raise TypeError(
            "Dataset.as_numpy_iterator() does not support datasets containing "
            + str(component_spec.value_type))

    return _NumpyIterator(self)

  @property
  def _flat_shapes(self):
    """Returns a list of `tf.TensorShape`s for the element tensor
    representation.

    Returns:
      A list of `tf.TensorShape`s for the element tensor representation.
    """
    return structure.get_flat_tensor_shapes(self.element_spec)

  @property
  def _flat_types(self):
    """Returns a list of `tf.DType`s for the element tensor representation.

    Returns:
      A list of `tf.DType`s for the element tensor representation.
    """
    return structure.get_flat_tensor_types(self.element_spec)

  @property
  def _flat_structure(self):
    """Helper for setting `output_shapes` and `output_types` attrs of an op.

    Most dataset op constructors expect `output_shapes` and `output_types`
    arguments that represent the flattened structure of an element. This helper
    function generates these attrs as a keyword argument dictionary, allowing
    `Dataset._variant_tensor` implementations to pass `**self._flat_structure`
    to the op constructor.

    Returns:
      A dictionary of keyword arguments that can be passed to a dataset op
      constructor.
    """
    return {
        "output_shapes": self._flat_shapes,
        "output_types": self._flat_types,
    }

  @property
  def _type_spec(self):
    return DatasetSpec(self.element_spec)

  @staticmethod
  def from_tensors(tensors):
    """Creates a `Dataset` with a single element, comprising the given tensors.

    `from_tensors` produces a dataset containing only a single element. To
    slice the input tensor into multiple elements, use `from_tensor_slices`
    instead.

    >>> dataset = tf.data.Dataset.from_tensors([1, 2, 3])
    >>> list(dataset.as_numpy_iterator())
    [array([1, 2, 3], dtype=int32)]
    >>> dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A'))
    >>> list(dataset.as_numpy_iterator())
    [(array([1, 2, 3], dtype=int32), b'A')]

    >>> # You can use `from_tensors` to produce a dataset which repeats
    >>> # the same example many times.
    >>> example = tf.constant([1,2,3])
    >>> dataset = tf.data.Dataset.from_tensors(example).repeat(2)
    >>> list(dataset.as_numpy_iterator())
    [array([1, 2, 3], dtype=int32), array([1, 2, 3], dtype=int32)]

    Note that if `tensors` contains a NumPy array, and eager execution is not
    enabled, the values will be embedded in the graph as one or more
    `tf.constant` operations. For large datasets (> 1 GB), this can waste
    memory and run into byte limits of graph serialization. If `tensors`
    contains one or more large NumPy arrays, consider the alternative described
    in [this
    guide](https://tensorflow.org/guide/data#consuming_numpy_arrays).
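
    For instance, a generator-based alternative (a minimal sketch, not
    necessarily the guide's exact approach; `big_array` is a hypothetical
    in-memory array) avoids embedding the data as a `tf.constant`:

    ```python
    big_array = np.zeros((10000, 128), dtype=np.float32)  # hypothetical data

    def gen():
      for row in big_array:
        yield row

    dataset = tf.data.Dataset.from_generator(
        gen,
        output_signature=tf.TensorSpec(shape=(128,), dtype=tf.float32))
    ```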

    Args:
      tensors: A dataset "element". Supported values are documented
        [here](https://www.tensorflow.org/guide/data#dataset_structure).

    Returns:
      Dataset: A `Dataset`.
    """
    return TensorDataset(tensors)

  @staticmethod
  def from_tensor_slices(tensors):
    """Creates a `Dataset` whose elements are slices of the given tensors.

    The given tensors are sliced along their first dimension. This operation
    preserves the structure of the input tensors, removing the first dimension
    of each tensor and using it as the dataset dimension. All input tensors
    must have the same size in their first dimensions.

    >>> # Slicing a 1D tensor produces scalar tensor elements.
    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> list(dataset.as_numpy_iterator())
    [1, 2, 3]

    >>> # Slicing a 2D tensor produces 1D tensor elements.
    >>> dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
    >>> list(dataset.as_numpy_iterator())
    [array([1, 2], dtype=int32), array([3, 4], dtype=int32)]

    >>> # Slicing a tuple of 1D tensors produces tuple elements containing
    >>> # scalar tensors.
    >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
    >>> list(dataset.as_numpy_iterator())
    [(1, 3, 5), (2, 4, 6)]

    >>> # Dictionary structure is also preserved.
    >>> dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
    >>> list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
    ...                                       {'a': 2, 'b': 4}]
    True

    >>> # Two tensors can be combined into one Dataset object.
    >>> features = tf.constant([[1, 3], [2, 1], [3, 3]])  # ==> 3x2 tensor
    >>> labels = tf.constant(['A', 'B', 'A'])  # ==> 3x1 tensor
    >>> dataset = Dataset.from_tensor_slices((features, labels))
    >>> # Both the features and the labels tensors can be converted
    >>> # to a Dataset object separately and combined after.
    >>> features_dataset = Dataset.from_tensor_slices(features)
    >>> labels_dataset = Dataset.from_tensor_slices(labels)
    >>> dataset = Dataset.zip((features_dataset, labels_dataset))
    >>> # A batched feature and label set can be converted to a Dataset
    >>> # in similar fashion.
    >>> batched_features = tf.constant([[[1, 3], [2, 3]],
    ...                                 [[2, 1], [1, 2]],
    ...                                 [[3, 3], [3, 2]]], shape=(3, 2, 2))
    >>> batched_labels = tf.constant([['A', 'A'],
    ...                               ['B', 'B'],
    ...                               ['A', 'B']], shape=(3, 2, 1))
    >>> dataset = Dataset.from_tensor_slices((batched_features, batched_labels))
    >>> for element in dataset.as_numpy_iterator():
    ...   print(element)
    (array([[1, 3],
           [2, 3]], dtype=int32), array([[b'A'],
           [b'A']], dtype=object))
    (array([[2, 1],
           [1, 2]], dtype=int32), array([[b'B'],
           [b'B']], dtype=object))
    (array([[3, 3],
           [3, 2]], dtype=int32), array([[b'A'],
           [b'B']], dtype=object))

    Note that if `tensors` contains a NumPy array, and eager execution is not
    enabled, the values will be embedded in the graph as one or more
    `tf.constant` operations. For large datasets (> 1 GB), this can waste
    memory and run into byte limits of graph serialization. If `tensors`
    contains one or more large NumPy arrays, consider the alternative described
    in [this guide](
    https://tensorflow.org/guide/data#consuming_numpy_arrays).

    Args:
      tensors: A dataset element, whose components have the same first
        dimension.
        Supported values are documented
        [here](https://www.tensorflow.org/guide/data#dataset_structure).

    Returns:
      Dataset: A `Dataset`.
    """
    return TensorSliceDataset(tensors)

  class _GeneratorState(object):
    """Stores outstanding iterators created from a Python generator.

    This class keeps track of potentially multiple iterators that may have
    been created from a generator, e.g. in the case that the dataset is
    repeated, or nested within a parallel computation.
    """

    def __init__(self, generator):
      self._generator = generator
      self._lock = threading.Lock()
      self._next_id = 0  # GUARDED_BY(self._lock)
      self._args = {}
      self._iterators = {}

    def get_next_id(self, *args):
      with self._lock:
        ret = self._next_id
        self._next_id += 1
      self._args[ret] = args
      # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
      # casting in `py_func()` will create an array of `np.int32` on Windows,
      # leading to a runtime error.
      return np.array(ret, dtype=np.int64)

    def get_iterator(self, iterator_id):
      try:
        return self._iterators[iterator_id]
      except KeyError:
        iterator = iter(self._generator(*self._args.pop(iterator_id)))
        self._iterators[iterator_id] = iterator
        return iterator

    def iterator_completed(self, iterator_id):
      del self._iterators[iterator_id]

  @staticmethod
  @deprecation.deprecated_args(None, "Use output_signature instead",
                               "output_types", "output_shapes")
  def from_generator(generator,
                     output_types=None,
                     output_shapes=None,
                     args=None,
                     output_signature=None):
    """Creates a `Dataset` whose elements are generated by `generator`.

    The `generator` argument must be a callable object that returns
    an object that supports the `iter()` protocol (e.g. a generator function).

    The elements generated by `generator` must be compatible with either the
    given `output_signature` argument or with the given `output_types` and
    (optionally) `output_shapes` arguments, whichever was specified.

    The recommended way to call `from_generator` is to use the
    `output_signature` argument. In this case the output will be assumed to
    consist of objects with the classes, shapes and types defined by the
    `tf.TypeSpec` objects from the `output_signature` argument:

    >>> def gen():
    ...   ragged_tensor = tf.ragged.constant([[1, 2], [3]])
    ...   yield 42, ragged_tensor
    >>>
    >>> dataset = tf.data.Dataset.from_generator(
    ...     gen,
    ...     output_signature=(
    ...         tf.TensorSpec(shape=(), dtype=tf.int32),
    ...         tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))
    >>>
    >>> list(dataset.take(1))
    [(<tf.Tensor: shape=(), dtype=int32, numpy=42>,
    <tf.RaggedTensor [[1, 2], [3]]>)]

    There is also a deprecated way to call `from_generator`, using the
    `output_types` argument either alone or together with the `output_shapes`
    argument. In this case the output of the function will be assumed to
    consist of `tf.Tensor` objects with the types defined by `output_types`
    and with the shapes which are either unknown or defined by
    `output_shapes`.

    Note: The current implementation of `Dataset.from_generator()` uses
    `tf.numpy_function` and inherits the same constraints. In particular, it
    requires the dataset and iterator related operations to be placed
    on a device in the same process as the Python program that called
    `Dataset.from_generator()`.
    The body of `generator` will not be serialized in a `GraphDef`, and you
    should not use this method if you need to serialize your model and restore
    it in a different environment.

    Note: If `generator` depends on mutable global variables or other external
    state, be aware that the runtime may invoke `generator` multiple times
    (in order to support repeating the `Dataset`) and at any time
    between the call to `Dataset.from_generator()` and the production of the
    first element from the generator. Mutating global variables or external
    state can cause undefined behavior, and we recommend that you explicitly
    cache any external state in `generator` before calling
    `Dataset.from_generator()`.

    Note: While the `output_signature` parameter makes it possible to yield
    `Dataset` elements, the scope of `Dataset.from_generator()` should be
    limited to logic that cannot be expressed through tf.data operations. Using
    tf.data operations within the generator function is an anti-pattern and may
    result in incremental memory growth.
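
    For example (a minimal sketch; the generator and its `start` argument are
    hypothetical), the `args` parameter can be used to pass values to
    `generator`, which receives them as NumPy arrays:

    ```python
    def gen(start):
      # `start` arrives as a NumPy array; convert it to a Python int.
      for i in range(int(start), int(start) + 3):
        yield i

    dataset = tf.data.Dataset.from_generator(
        gen,
        args=(10,),
        output_signature=tf.TensorSpec(shape=(), dtype=tf.int64))
    # Produces 10, 11, 12.
    ```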

    Args:
      generator: A callable object that returns an object that supports the
        `iter()` protocol. If `args` is not specified, `generator` must take no
        arguments; otherwise it must take as many arguments as there are values
        in `args`.
      output_types: (Optional.) A (nested) structure of `tf.DType` objects
        corresponding to each component of an element yielded by `generator`.
      output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
        objects corresponding to each component of an element yielded by
        `generator`.
      args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
        and passed to `generator` as NumPy-array arguments.
      output_signature: (Optional.) A (nested) structure of `tf.TypeSpec`
        objects corresponding to each component of an element yielded by
        `generator`.

    Returns:
      Dataset: A `Dataset`.
    """
    if not callable(generator):
      raise TypeError("`generator` must be callable.")

    if output_signature is not None:
      if output_types is not None:
        raise TypeError("`output_types` can not be used together with "
                        "`output_signature`")
      if output_shapes is not None:
        raise TypeError("`output_shapes` can not be used together with "
                        "`output_signature`")
      if not all(
          isinstance(_, type_spec.TypeSpec)
          for _ in nest.flatten(output_signature)):
        raise TypeError("All the elements of `output_signature` must be "
                        "`tf.TypeSpec` objects.")
    else:
      if output_types is None:
        raise TypeError("Either `output_signature` or `output_types` must "
                        "be specified")

    if output_signature is None:
      if output_shapes is None:
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_types)
      else:
        output_shapes = nest.map_structure_up_to(output_types,
                                                 tensor_shape.as_shape,
                                                 output_shapes)
      output_signature = nest.map_structure_up_to(output_types,
                                                  tensor_spec.TensorSpec,
                                                  output_shapes, output_types)
    if all(
        isinstance(x, tensor_spec.TensorSpec)
        for x in nest.flatten(output_signature)):
      output_types = nest.pack_sequence_as(
          output_signature, [x.dtype for x in nest.flatten(output_signature)])
      output_shapes = nest.pack_sequence_as(
          output_signature, [x.shape for x in nest.flatten(output_signature)])

    if args is None:
      args = ()
    else:
      args = tuple(ops.convert_n_to_tensor(args, name="args"))

    generator_state = DatasetV2._GeneratorState(generator)

    def get_iterator_id_fn(unused_dummy):
      """Creates a unique `iterator_id` for each pass over the dataset.

      The returned `iterator_id` disambiguates between multiple concurrently
      existing iterators.

      Args:
        unused_dummy: Ignored value.

      Returns:
        A `tf.int64` tensor whose value uniquely identifies an iterator in
        `generator_state`.
      """
      return script_ops.numpy_function(generator_state.get_next_id, args,
                                       dtypes.int64)

    def generator_next_fn(iterator_id_t):
      """Generates the next element from iterator with ID `iterator_id_t`.

      We map this function across an infinite repetition of the
      `iterator_id_t`, and raise `StopIteration` to terminate the iteration.

      Args:
        iterator_id_t: A `tf.int64` tensor whose value uniquely identifies the
          iterator in `generator_state` from which to generate an element.

      Returns:
        The next element to generate from the iterator.
      """
      if output_types and output_shapes:
        flattened_types = [
            dtypes.as_dtype(dt) for dt in nest.flatten(output_types)
        ]
        flattened_shapes = nest.flatten(output_shapes)

        def generator_py_func(iterator_id):
          """A `py_func` that will be called to invoke the iterator."""
          # `next()` raises `StopIteration` when there are no more
          # elements remaining to be generated.
          values = next(generator_state.get_iterator(iterator_id))

          # Use the same _convert function from the py_func() implementation
          # to convert the returned values to arrays early, so that we can
          # inspect their values.
          try:
            flattened_values = nest.flatten_up_to(output_types, values)
          except (TypeError, ValueError):
            six.reraise(
                TypeError,
                TypeError(
                    "`generator` yielded an element that did not match the "
                    "expected structure. The expected structure was %s, but "
                    "the yielded element was %s." % (output_types, values)),
                sys.exc_info()[2])
          ret_arrays = []
          for ret, dtype in zip(flattened_values, flattened_types):
            try:
              ret_arrays.append(
                  script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
                      ret,
                      dtype=dtype.as_numpy_dtype))
            except (TypeError, ValueError):
              six.reraise(
                  TypeError,
                  TypeError(
                      "`generator` yielded an element that could not be "
                      "converted to the expected type. The expected type was "
                      "%s, but the yielded element was %s." %
                      (dtype.name, ret)),
                  sys.exc_info()[2])

          # Additional type and shape checking to ensure that the components
          # of the generated element match the `output_types` and
          # `output_shapes` arguments.
          for (ret_array, expected_dtype,
               expected_shape) in zip(ret_arrays, flattened_types,
                                      flattened_shapes):
            if ret_array.dtype != expected_dtype.as_numpy_dtype:
              raise TypeError(
                  "`generator` yielded an element of type %s where an element "
                  "of type %s was expected." %
                  (ret_array.dtype, expected_dtype.as_numpy_dtype))
            if not expected_shape.is_compatible_with(ret_array.shape):
              raise ValueError(
                  "`generator` yielded an element of shape %s where an "
                  "element of shape %s was expected." %
                  (ret_array.shape, expected_shape))

          return ret_arrays

        flat_values = script_ops.numpy_function(generator_py_func,
                                                [iterator_id_t],
                                                flattened_types)

        # The `py_func()` op drops the inferred shapes, so we add them back in
        # here.
        if output_shapes is not None:
          for ret_t, shape in zip(flat_values, flattened_shapes):
            ret_t.set_shape(shape)

        return nest.pack_sequence_as(output_types, flat_values)
      else:
        flat_output_types = structure.get_flat_tensor_types(output_signature)

        def generator_py_func(iterator_id):
          """A `py_func` that will be called to invoke the iterator."""
          # `next()` raises `StopIteration` when there are no more
          # elements remaining to be generated.
          values = next(generator_state.get_iterator(iterator_id.numpy()))

          try:
            values = structure.normalize_element(values, output_signature)
          except (TypeError, ValueError):
            six.reraise(
                TypeError,
                TypeError(
                    "`generator` yielded an element that did not match the "
                    "expected structure. The expected structure was %s, but "
                    "the yielded element was %s." %
                    (output_signature, values)),
                sys.exc_info()[2])

          values_spec = structure.type_spec_from_value(values)

          if not structure.are_compatible(values_spec, output_signature):
            raise TypeError(
                "`generator` yielded an element of %s where an element "
                "of %s was expected." % (values_spec, output_signature))

          return structure.to_tensor_list(output_signature, values)

        return script_ops._eager_py_func(  # pylint: disable=protected-access
            generator_py_func,
            inp=[iterator_id_t],
            Tout=flat_output_types,
            use_tape_cache=False)

    def finalize_fn(iterator_id_t):
      """Releases host-side state for the iterator with ID `iterator_id_t`."""

      def finalize_py_func(iterator_id):
        generator_state.iterator_completed(iterator_id)
        # We return a dummy value so that the `finalize_fn` has a valid
        # signature.
        # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
        # casting in `py_func()` will create an array of `np.int32` on Windows,
        # leading to a runtime error.
        return np.array(0, dtype=np.int64)

      return script_ops.numpy_function(finalize_py_func, [iterator_id_t],
                                       dtypes.int64)

    # This function associates each traversal of `generator` with a unique
    # iterator ID.
    def flat_map_fn(dummy_arg):
      # The `get_iterator_id_fn` gets a unique ID for the current instance of
      # the generator.
      # The `generator_next_fn` gets the next element from the iterator with
      # the given ID, and raises StopIteration when that iterator contains no
      # more elements.
      return _GeneratorDataset(dummy_arg, get_iterator_id_fn,
                               generator_next_fn, finalize_fn,
                               output_signature)

    # A single-element dataset that, each time it is evaluated, contains a
    # freshly-generated and unique (for the returned dataset) int64
    # ID that will be used to identify the appropriate Python state, which
    # is encapsulated in `generator_state`, and captured in
    # `get_iterator_id_map_fn`.
    dummy = 0
    id_dataset = Dataset.from_tensors(dummy)

    # A dataset that contains all of the elements generated by a
    # single iterator created from `generator`, identified by the
    # iterator ID contained in `id_dataset`. Lifting the iteration
    # into a flat_map here enables multiple repetitions and/or nested
    # versions of the returned dataset to be created, because it forces
    # the generation of a new ID for each version.
    return id_dataset.flat_map(flat_map_fn)

  @staticmethod
  def range(*args, **kwargs):
    """Creates a `Dataset` of a step-separated range of values.

    >>> list(Dataset.range(5).as_numpy_iterator())
    [0, 1, 2, 3, 4]
    >>> list(Dataset.range(2, 5).as_numpy_iterator())
    [2, 3, 4]
    >>> list(Dataset.range(1, 5, 2).as_numpy_iterator())
    [1, 3]
    >>> list(Dataset.range(1, 5, -2).as_numpy_iterator())
    []
    >>> list(Dataset.range(5, 1).as_numpy_iterator())
    []
    >>> list(Dataset.range(5, 1, -2).as_numpy_iterator())
    [5, 3]
    >>> list(Dataset.range(2, 5, output_type=tf.int32).as_numpy_iterator())
    [2, 3, 4]
    >>> list(Dataset.range(1, 5, 2, output_type=tf.float32).as_numpy_iterator())
    [1.0, 3.0]

    Args:
      *args: follows the same semantics as python's range.
        len(args) == 1 -> start = 0, stop = args[0], step = 1.
        len(args) == 2 -> start = args[0], stop = args[1], step = 1.
        len(args) == 3 -> start = args[0], stop = args[1], step = args[2].
      **kwargs:
        - output_type: Its expected dtype. (Optional, default: `tf.int64`).

    Returns:
      Dataset: A `RangeDataset`.

    Raises:
      ValueError: if len(args) == 0.
    """
    return RangeDataset(*args, **kwargs)

  @staticmethod
  def zip(datasets):
    """Creates a `Dataset` by zipping together the given datasets.

    This method has similar semantics to the built-in `zip()` function
    in Python, with the main difference being that the `datasets`
    argument can be a (nested) structure of `Dataset` objects. The supported
    nesting mechanisms are documented
    [here](https://www.tensorflow.org/guide/data#dataset_structure).

    >>> # The nested structure of the `datasets` argument determines the
    >>> # structure of elements in the resulting dataset.
    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
    >>> b = tf.data.Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
    >>> ds = tf.data.Dataset.zip((a, b))
    >>> list(ds.as_numpy_iterator())
    [(1, 4), (2, 5), (3, 6)]
    >>> ds = tf.data.Dataset.zip((b, a))
    >>> list(ds.as_numpy_iterator())
    [(4, 1), (5, 2), (6, 3)]
    >>>
    >>> # The `datasets` argument may contain an arbitrary number of datasets.
    >>> c = tf.data.Dataset.range(7, 13).batch(2)  # ==> [ [7, 8],
    ...                                            #      [9, 10],
    ...                                            #      [11, 12] ]
    >>> ds = tf.data.Dataset.zip((a, b, c))
    >>> for element in ds.as_numpy_iterator():
    ...   print(element)
    (1, 4, array([7, 8]))
    (2, 5, array([ 9, 10]))
    (3, 6, array([11, 12]))
    >>>
    >>> # The number of elements in the resulting dataset is the same as
    >>> # the size of the smallest dataset in `datasets`.
    >>> d = tf.data.Dataset.range(13, 15)  # ==> [ 13, 14 ]
    >>> ds = tf.data.Dataset.zip((a, d))
    >>> list(ds.as_numpy_iterator())
    [(1, 13), (2, 14)]

    Args:
      datasets: A (nested) structure of datasets.

    Returns:
      Dataset: A `Dataset`.
    """
    return ZipDataset(datasets)

  def concatenate(self, dataset):
    """Creates a `Dataset` by concatenating the given dataset with this dataset.

    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
    >>> b = tf.data.Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
    >>> ds = a.concatenate(b)
    >>> list(ds.as_numpy_iterator())
    [1, 2, 3, 4, 5, 6, 7]
    >>> # The input dataset and dataset to be concatenated should have
    >>> # compatible element specs.
    >>> c = tf.data.Dataset.zip((a, b))
    >>> a.concatenate(c)
    Traceback (most recent call last):
    TypeError: Two datasets to concatenate have different types
    <dtype: 'int64'> and (tf.int64, tf.int64)
    >>> d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
    >>> a.concatenate(d)
    Traceback (most recent call last):
    TypeError: Two datasets to concatenate have different types
    <dtype: 'int64'> and <dtype: 'string'>

    Args:
      dataset: `Dataset` to be concatenated.

    Returns:
      Dataset: A `Dataset`.
    """
    return ConcatenateDataset(self, dataset)

  def prefetch(self, buffer_size):
    """Creates a `Dataset` that prefetches elements from this dataset.

    Most dataset input pipelines should end with a call to `prefetch`. This
    allows later elements to be prepared while the current element is being
    processed. This often improves latency and throughput, at the cost of
    using additional memory to store prefetched elements.

    Note: Like other `Dataset` methods, prefetch operates on the
    elements of the input dataset. It has no concept of examples vs. batches.
    `examples.prefetch(2)` will prefetch two elements (2 examples),
    while `examples.batch(20).prefetch(2)` will prefetch 2 elements
    (2 batches, of 20 examples each).

    >>> dataset = tf.data.Dataset.range(3)
    >>> dataset = dataset.prefetch(2)
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 2]
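
    A common pattern (a brief sketch; the map stage here is arbitrary) is to
    let tf.data tune the buffer size automatically:

    ```python
    dataset = tf.data.Dataset.range(3)
    dataset = dataset.map(lambda x: x * 2)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    ```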

    Args:
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the maximum
        number of elements that will be buffered when prefetching. If the value
        `tf.data.AUTOTUNE` is used, then the buffer size is dynamically tuned.

    Returns:
      Dataset: A `Dataset`.
    """
    if DEBUG_MODE:
      return self
    return PrefetchDataset(self, buffer_size)

  @staticmethod
  def list_files(file_pattern, shuffle=None, seed=None):
    """A dataset of all files matching one or more glob patterns.

    The `file_pattern` argument should be a small number of glob patterns.
    If your filenames have already been globbed, use
    `Dataset.from_tensor_slices(filenames)` instead, as re-globbing every
    filename with `list_files` may result in poor performance with remote
    storage systems.

    Note: The default behavior of this method is to return filenames in
    a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False`
    to get results in a deterministic order.

    Example:
      If we had the following files on our filesystem:

        - /path/to/dir/a.txt
        - /path/to/dir/b.py
        - /path/to/dir/c.py

      If we pass "/path/to/dir/*.py" as the `file_pattern`, the dataset
      would produce:

        - /path/to/dir/b.py
        - /path/to/dir/c.py

    Args:
      file_pattern: A string, a list of strings, or a `tf.Tensor` of string
        type (scalar or vector), representing the filename glob (i.e. shell
        wildcard) pattern(s) that will be matched.
      shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
        Defaults to `True`.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        random seed that will be used to create the distribution. See
        `tf.random.set_seed` for behavior.

    Returns:
      Dataset: A `Dataset` of strings corresponding to file names.
    """
    with ops.name_scope("list_files"):
      if shuffle is None:
        shuffle = True
      file_pattern = ops.convert_to_tensor(
          file_pattern, dtype=dtypes.string, name="file_pattern")
      matching_files = gen_io_ops.matching_files(file_pattern)

      # Raise an exception if `file_pattern` does not match any files.
      condition = math_ops.greater(array_ops.shape(matching_files)[0], 0,
                                   name="match_not_empty")

      message = math_ops.add(
          "No files matched pattern: ",
          string_ops.reduce_join(file_pattern, separator=", "), name="message")

      assert_not_empty = control_flow_ops.Assert(
          condition, [message], summarize=1, name="assert_not_empty")
      with ops.control_dependencies([assert_not_empty]):
        matching_files = array_ops.identity(matching_files)

      dataset = Dataset.from_tensor_slices(matching_files)
      if shuffle:
        # NOTE(mrry): The shuffle buffer size must be greater than zero, but
        # the list of files might be empty.
        buffer_size = math_ops.maximum(
            array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
        dataset = dataset.shuffle(buffer_size, seed=seed)
      return dataset

  def repeat(self, count=None):
    """Repeats this dataset so each original value is seen `count` times.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> dataset = dataset.repeat(3)
    >>> list(dataset.as_numpy_iterator())
    [1, 2, 3, 1, 2, 3, 1, 2, 3]

    Note: If this dataset is a function of global state (e.g. a random number
    generator), then different repetitions may produce different elements.
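
    A minimal sketch of repeating indefinitely (note that without `take`, an
    infinite dataset never terminates iteration):

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2])
    >>> dataset = dataset.repeat().take(5)
    >>> list(dataset.as_numpy_iterator())
    [1, 2, 1, 2, 1]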

    Args:
      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        number of times the dataset should be repeated. The default behavior
        (if `count` is `None` or `-1`) is for the dataset to be repeated
        indefinitely.

    Returns:
      Dataset: A `Dataset`.
    """
    return RepeatDataset(self, count)

  def enumerate(self, start=0):
    """Enumerates the elements of this dataset.

    It is similar to python's `enumerate`.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> dataset = dataset.enumerate(start=5)
    >>> for element in dataset.as_numpy_iterator():
    ...   print(element)
    (5, 1)
    (6, 2)
    (7, 3)

    >>> # The (nested) structure of the input dataset determines the
    >>> # structure of elements in the resulting dataset.
    >>> dataset = tf.data.Dataset.from_tensor_slices([(7, 8), (9, 10)])
    >>> dataset = dataset.enumerate()
    >>> for element in dataset.as_numpy_iterator():
    ...   print(element)
    (0, array([7, 8], dtype=int32))
    (1, array([ 9, 10], dtype=int32))

    Args:
      start: A `tf.int64` scalar `tf.Tensor`, representing the start value for
        enumeration.

    Returns:
      Dataset: A `Dataset`.
    """

    max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
    return Dataset.zip((Dataset.range(start, max_value), self))

  def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
    """Randomly shuffles the elements of this dataset.

    This dataset fills a buffer with `buffer_size` elements, then randomly
    samples elements from this buffer, replacing the selected elements with new
    elements. For perfect shuffling, a buffer size greater than or equal to the
    full size of the dataset is required.

    For instance, if your dataset contains 10,000 elements but `buffer_size` is
    set to 1,000, then `shuffle` will initially select a random element from
    only the first 1,000 elements in the buffer. Once an element is selected,
    its space in the buffer is replaced by the next (i.e. 1,001-st) element,
    maintaining the 1,000 element buffer.

    `reshuffle_each_iteration` controls whether the shuffle order should be
    different for each epoch. In TF 1.X, the idiomatic way to create epochs
    was through the `repeat` transformation:

    ```python
    dataset = tf.data.Dataset.range(3)
    dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
    dataset = dataset.repeat(2)
    # [1, 0, 2, 1, 2, 0]

    dataset = tf.data.Dataset.range(3)
    dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
    dataset = dataset.repeat(2)
    # [1, 0, 2, 1, 0, 2]
    ```

    In TF 2.0, `tf.data.Dataset` objects are Python iterables which makes it
    possible to also create epochs through Python iteration:

    ```python
    dataset = tf.data.Dataset.range(3)
    dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
    list(dataset.as_numpy_iterator())
    # [1, 0, 2]
    list(dataset.as_numpy_iterator())
    # [1, 2, 0]
    ```

    ```python
    dataset = tf.data.Dataset.range(3)
    dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
    list(dataset.as_numpy_iterator())
    # [1, 0, 2]
    list(dataset.as_numpy_iterator())
    # [1, 0, 2]
    ```

    Args:
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        elements from this dataset from which the new dataset will sample.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        random seed that will be used to create the distribution. See
        `tf.random.set_seed` for behavior.
      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
        that the dataset should be pseudorandomly reshuffled each time it is
        iterated over. (Defaults to `True`.)

    Returns:
      Dataset: A `Dataset`.
    """
    return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)

  def cache(self, filename=""):
    """Caches the elements in this dataset.

    The first time the dataset is iterated over, its elements will be cached
    either in the specified file or in memory. Subsequent iterations will
    use the cached data.

    Note: For the cache to be finalized, the input dataset must be iterated
    through in its entirety. Otherwise, subsequent iterations will not use
    cached data.

    >>> dataset = tf.data.Dataset.range(5)
    >>> dataset = dataset.map(lambda x: x**2)
    >>> dataset = dataset.cache()
    >>> # The first time reading through the data will generate the data using
    >>> # `range` and `map`.
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 4, 9, 16]
    >>> # Subsequent iterations read from the cache.
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 4, 9, 16]

    When caching to a file, the cached data will persist across runs. Even the
    first iteration through the data will read from the cache file. Changing
    the input pipeline before the call to `.cache()` will have no effect until
    the cache file is removed or the filename is changed.

    ```python
    dataset = tf.data.Dataset.range(5)
    dataset = dataset.cache("/path/to/file")
    list(dataset.as_numpy_iterator())
    # [0, 1, 2, 3, 4]
    dataset = tf.data.Dataset.range(10)
    dataset = dataset.cache("/path/to/file")  # Same file!
    list(dataset.as_numpy_iterator())
    # [0, 1, 2, 3, 4]
    ```

    Note: `cache` will produce exactly the same elements during each iteration
    through the dataset. If you wish to randomize the iteration order, make
    sure to call `shuffle` *after* calling `cache`.

    Args:
      filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
        directory on the filesystem to use for caching elements in this
        Dataset. If a filename is not provided, the dataset will be cached in
        memory.

    Returns:
      Dataset: A `Dataset`.
    """
    return CacheDataset(self, filename)

  def take(self, count):
    """Creates a `Dataset` with at most `count` elements from this dataset.

    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.take(3)
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 2]

    Args:
      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
        elements of this dataset that should be taken to form the new dataset.
        If `count` is -1, or if `count` is greater than the size of this
        dataset, the new dataset will contain all elements of this dataset.

    Returns:
      Dataset: A `Dataset`.
    """
    return TakeDataset(self, count)

  def skip(self, count):
    """Creates a `Dataset` that skips `count` elements from this dataset.

    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.skip(7)
    >>> list(dataset.as_numpy_iterator())
    [7, 8, 9]

    Args:
      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
        elements of this dataset that should be skipped to form the new
        dataset. If `count` is greater than the size of this dataset, the new
        dataset will contain no elements. If `count` is -1, skips the entire
        dataset.
If `count` is -1, skips the entire dataset. 1453 1454 Returns: 1455 Dataset: A `Dataset`. 1456 """ 1457 return SkipDataset(self, count) 1458 1459 def shard(self, num_shards, index): 1460 """Creates a `Dataset` that includes only 1/`num_shards` of this dataset. 1461 1462 `shard` is deterministic. The Dataset produced by `A.shard(n, i)` will 1463 contain all elements of A whose index mod n = i. 1464 1465 >>> A = tf.data.Dataset.range(10) 1466 >>> B = A.shard(num_shards=3, index=0) 1467 >>> list(B.as_numpy_iterator()) 1468 [0, 3, 6, 9] 1469 >>> C = A.shard(num_shards=3, index=1) 1470 >>> list(C.as_numpy_iterator()) 1471 [1, 4, 7] 1472 >>> D = A.shard(num_shards=3, index=2) 1473 >>> list(D.as_numpy_iterator()) 1474 [2, 5, 8] 1475 1476 This dataset operator is very useful when running distributed training, as 1477 it allows each worker to read a unique subset. 1478 1479 When reading a single input file, you can shard elements as follows: 1480 1481 ```python 1482 d = tf.data.TFRecordDataset(input_file) 1483 d = d.shard(num_workers, worker_index) 1484 d = d.repeat(num_epochs) 1485 d = d.shuffle(shuffle_buffer_size) 1486 d = d.map(parser_fn, num_parallel_calls=num_map_threads) 1487 ``` 1488 1489 Important caveats: 1490 1491 - Be sure to shard before you use any randomizing operator (such as 1492 shuffle). 1493 - Generally it is best if the shard operator is used early in the dataset 1494 pipeline. For example, when reading from a set of TFRecord files, shard 1495 before converting the dataset to input samples. This avoids reading every 1496 file on every worker. The following is an example of an efficient 1497 sharding strategy within a complete pipeline: 1498 1499 ```python 1500 d = Dataset.list_files(pattern) 1501 d = d.shard(num_workers, worker_index) 1502 d = d.repeat(num_epochs) 1503 d = d.shuffle(shuffle_buffer_size) 1504 d = d.interleave(tf.data.TFRecordDataset, 1505 cycle_length=num_readers, block_length=1) 1506 d = d.map(parser_fn, num_parallel_calls=num_map_threads) 1507 ``` 1508 1509 Args: 1510 num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of 1511 shards operating in parallel. 1512 index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. 1513 1514 Returns: 1515 Dataset: A `Dataset`. 1516 1517 Raises: 1518 InvalidArgumentError: if `num_shards` or `index` are illegal values. 1519 1520 Note: error checking is done on a best-effort basis, and errors aren't 1521 guaranteed to be caught upon dataset creation. (e.g. providing in a 1522 placeholder tensor bypasses the early checking, and will instead result 1523 in an error during a session.run call.) 1524 """ 1525 return ShardDataset(self, num_shards, index) 1526 1527 def batch(self, 1528 batch_size, 1529 drop_remainder=False, 1530 num_parallel_calls=None, 1531 deterministic=None): 1532 """Combines consecutive elements of this dataset into batches. 1533 1534 >>> dataset = tf.data.Dataset.range(8) 1535 >>> dataset = dataset.batch(3) 1536 >>> list(dataset.as_numpy_iterator()) 1537 [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])] 1538 1539 >>> dataset = tf.data.Dataset.range(8) 1540 >>> dataset = dataset.batch(3, drop_remainder=True) 1541 >>> list(dataset.as_numpy_iterator()) 1542 [array([0, 1, 2]), array([3, 4, 5])] 1543 1544 The components of the resulting element will have an additional outer 1545 dimension, which will be `batch_size` (or `N % batch_size` for the last 1546 element if `batch_size` does not divide the number of input elements `N` 1547 evenly and `drop_remainder` is `False`). 
If your program depends on the 1548 batches having the same outer dimension, you should set the `drop_remainder` 1549 argument to `True` to prevent the smaller batch from being produced. 1550 1551 Note: If your program requires data to have a statically known shape (e.g., 1552 when using XLA), you should use `drop_remainder=True`. Without 1553 `drop_remainder=True` the shape of the output dataset will have an unknown 1554 leading dimension due to the possibility of a smaller final batch. 1555 1556 Args: 1557 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 1558 consecutive elements of this dataset to combine in a single batch. 1559 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 1560 whether the last batch should be dropped in the case it has fewer than 1561 `batch_size` elements; the default behavior is not to drop the smaller 1562 batch. 1563 num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`, 1564 representing the number of batches to compute asynchronously in 1565 parallel. 1566 If not specified, batches will be computed sequentially. If the value 1567 `tf.data.AUTOTUNE` is used, then the number of parallel 1568 calls is set dynamically based on available resources. 1569 deterministic: (Optional.) When `num_parallel_calls` is specified, if this 1570 boolean is specified (`True` or `False`), it controls the order in which 1571 the transformation produces elements. If set to `False`, the 1572 transformation is allowed to yield elements out of order to trade 1573 determinism for performance. If not specified, the 1574 `tf.data.Options.deterministic` option (`True` by default) controls the 1575 behavior. 1576 1577 Returns: 1578 Dataset: A `Dataset`. 1579 """ 1580 if num_parallel_calls is None or DEBUG_MODE: 1581 if deterministic is not None and not DEBUG_MODE: 1582 warnings.warn("The `deterministic` argument has no effect unless the " 1583 "`num_parallel_calls` argument is specified.") 1584 return BatchDataset(self, batch_size, drop_remainder) 1585 else: 1586 return ParallelBatchDataset(self, batch_size, drop_remainder, 1587 num_parallel_calls, deterministic) 1588 1589 def padded_batch(self, 1590 batch_size, 1591 padded_shapes=None, 1592 padding_values=None, 1593 drop_remainder=False): 1594 """Combines consecutive elements of this dataset into padded batches. 1595 1596 This transformation combines multiple consecutive elements of the input 1597 dataset into a single element. 1598 1599 Like `tf.data.Dataset.batch`, the components of the resulting element will 1600 have an additional outer dimension, which will be `batch_size` (or 1601 `N % batch_size` for the last element if `batch_size` does not divide the 1602 number of input elements `N` evenly and `drop_remainder` is `False`). If 1603 your program depends on the batches having the same outer dimension, you 1604 should set the `drop_remainder` argument to `True` to prevent the smaller 1605 batch from being produced. 1606 1607 Unlike `tf.data.Dataset.batch`, the input elements to be batched may have 1608 different shapes, and this transformation will pad each component to the 1609 respective shape in `padded_shapes`. The `padded_shapes` argument 1610 determines the resulting shape for each dimension of each component in an 1611 output element: 1612 1613 * If the dimension is a constant, the component will be padded out to that 1614 length in that dimension. 
1615 * If the dimension is unknown, the component will be padded out to the 1616 maximum length of all elements in that dimension. 1617 1618 >>> A = (tf.data.Dataset 1619 ... .range(1, 5, output_type=tf.int32) 1620 ... .map(lambda x: tf.fill([x], x))) 1621 >>> # Pad to the smallest per-batch size that fits all elements. 1622 >>> B = A.padded_batch(2) 1623 >>> for element in B.as_numpy_iterator(): 1624 ... print(element) 1625 [[1 0] 1626 [2 2]] 1627 [[3 3 3 0] 1628 [4 4 4 4]] 1629 >>> # Pad to a fixed size. 1630 >>> C = A.padded_batch(2, padded_shapes=5) 1631 >>> for element in C.as_numpy_iterator(): 1632 ... print(element) 1633 [[1 0 0 0 0] 1634 [2 2 0 0 0]] 1635 [[3 3 3 0 0] 1636 [4 4 4 4 0]] 1637 >>> # Pad with a custom value. 1638 >>> D = A.padded_batch(2, padded_shapes=5, padding_values=-1) 1639 >>> for element in D.as_numpy_iterator(): 1640 ... print(element) 1641 [[ 1 -1 -1 -1 -1] 1642 [ 2 2 -1 -1 -1]] 1643 [[ 3 3 3 -1 -1] 1644 [ 4 4 4 4 -1]] 1645 >>> # Components of nested elements can be padded independently. 1646 >>> elements = [([1, 2, 3], [10]), 1647 ... ([4, 5], [11, 12])] 1648 >>> dataset = tf.data.Dataset.from_generator( 1649 ... lambda: iter(elements), (tf.int32, tf.int32)) 1650 >>> # Pad the first component of the tuple to length 4, and the second 1651 >>> # component to the smallest size that fits. 1652 >>> dataset = dataset.padded_batch(2, 1653 ... padded_shapes=([4], [None]), 1654 ... padding_values=(-1, 100)) 1655 >>> list(dataset.as_numpy_iterator()) 1656 [(array([[ 1, 2, 3, -1], [ 4, 5, -1, -1]], dtype=int32), 1657 array([[ 10, 100], [ 11, 12]], dtype=int32))] 1658 >>> # Pad with a single value and multiple components. 1659 >>> E = tf.data.Dataset.zip((A, A)).padded_batch(2, padding_values=-1) 1660 >>> for element in E.as_numpy_iterator(): 1661 ... print(element) 1662 (array([[ 1, -1], 1663 [ 2, 2]], dtype=int32), array([[ 1, -1], 1664 [ 2, 2]], dtype=int32)) 1665 (array([[ 3, 3, 3, -1], 1666 [ 4, 4, 4, 4]], dtype=int32), array([[ 3, 3, 3, -1], 1667 [ 4, 4, 4, 4]], dtype=int32)) 1668 1669 See also `tf.data.experimental.dense_to_sparse_batch`, which combines 1670 elements that may have different shapes into a `tf.sparse.SparseTensor`. 1671 1672 Args: 1673 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 1674 consecutive elements of this dataset to combine in a single batch. 1675 padded_shapes: (Optional.) A (nested) structure of `tf.TensorShape` or 1676 `tf.int64` vector tensor-like objects representing the shape to which 1677 the respective component of each input element should be padded prior 1678 to batching. Any unknown dimensions will be padded to the maximum size 1679 of that dimension in each batch. If unset, all dimensions of all 1680 components are padded to the maximum size in the batch. `padded_shapes` 1681 must be set if any component has an unknown rank. 1682 padding_values: (Optional.) A (nested) structure of scalar-shaped 1683 `tf.Tensor`, representing the padding values to use for the respective 1684 components. None represents that the (nested) structure should be padded 1685 with default values. Defaults are `0` for numeric types and the empty 1686 string for string types. The `padding_values` should have the same 1687 (nested) structure as the input dataset. If `padding_values` is a single 1688 element and the input dataset has multiple components, then the same 1689 `padding_values` will be used to pad every component of the dataset. 
1690 If `padding_values` is a scalar, then its value will be broadcasted 1691 to match the shape of each component. 1692 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 1693 whether the last batch should be dropped in the case it has fewer than 1694 `batch_size` elements; the default behavior is not to drop the smaller 1695 batch. 1696 1697 Returns: 1698 Dataset: A `Dataset`. 1699 1700 Raises: 1701 ValueError: If a component has an unknown rank, and the `padded_shapes` 1702 argument is not set. 1703 """ 1704 if padded_shapes is None: 1705 padded_shapes = get_legacy_output_shapes(self) 1706 # A `tf.TensorShape` is only false if its *rank* is unknown: 1707 # bool(tf.TensorShape(None)) is False 1708 if not all(nest.flatten(padded_shapes)): 1709 raise ValueError("You must set the `padded_shapes` argument to " 1710 "`Dataset.padded_batch` if any component of its " 1711 "input has an unknown rank") 1712 return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values, 1713 drop_remainder) 1714 1715 def map(self, map_func, num_parallel_calls=None, deterministic=None): 1716 """Maps `map_func` across the elements of this dataset. 1717 1718 This transformation applies `map_func` to each element of this dataset, and 1719 returns a new dataset containing the transformed elements, in the same 1720 order as they appeared in the input. `map_func` can be used to change both 1721 the values and the structure of a dataset's elements. Supported structure 1722 constructs are documented 1723 [here](https://www.tensorflow.org/guide/data#dataset_structure). 1724 1725 For example, `map` can be used for adding 1 to each element, or projecting a 1726 subset of element components. 1727 1728 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] 1729 >>> dataset = dataset.map(lambda x: x + 1) 1730 >>> list(dataset.as_numpy_iterator()) 1731 [2, 3, 4, 5, 6] 1732 1733 The input signature of `map_func` is determined by the structure of each 1734 element in this dataset. 1735 1736 >>> dataset = Dataset.range(5) 1737 >>> # `map_func` takes a single argument of type `tf.Tensor` with the same 1738 >>> # shape and dtype. 1739 >>> result = dataset.map(lambda x: x + 1) 1740 1741 >>> # Each element is a tuple containing two `tf.Tensor` objects. 1742 >>> elements = [(1, "foo"), (2, "bar"), (3, "baz")] 1743 >>> dataset = tf.data.Dataset.from_generator( 1744 ... lambda: elements, (tf.int32, tf.string)) 1745 >>> # `map_func` takes two arguments of type `tf.Tensor`. This function 1746 >>> # projects out just the first component. 1747 >>> result = dataset.map(lambda x_int, y_str: x_int) 1748 >>> list(result.as_numpy_iterator()) 1749 [1, 2, 3] 1750 1751 >>> # Each element is a dictionary mapping strings to `tf.Tensor` objects. 1752 >>> elements = ([{"a": 1, "b": "foo"}, 1753 ... {"a": 2, "b": "bar"}, 1754 ... {"a": 3, "b": "baz"}]) 1755 >>> dataset = tf.data.Dataset.from_generator( 1756 ... lambda: elements, {"a": tf.int32, "b": tf.string}) 1757 >>> # `map_func` takes a single argument of type `dict` with the same keys 1758 >>> # as the elements. 1759 >>> result = dataset.map(lambda d: str(d["a"]) + d["b"]) 1760 1761 The value or values returned by `map_func` determine the structure of each 1762 element in the returned dataset. 1763 1764 >>> dataset = tf.data.Dataset.range(3) 1765 >>> # `map_func` returns two `tf.Tensor` objects. 1766 >>> def g(x): 1767 ... 
return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"]) 1768 >>> result = dataset.map(g) 1769 >>> result.element_spec 1770 (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(3,), \ 1771dtype=tf.string, name=None)) 1772 >>> # Python primitives, lists, and NumPy arrays are implicitly converted to 1773 >>> # `tf.Tensor`. 1774 >>> def h(x): 1775 ... return 37.0, ["Foo", "Bar"], np.array([1.0, 2.0], dtype=np.float64) 1776 >>> result = dataset.map(h) 1777 >>> result.element_spec 1778 (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(2,), \ 1779dtype=tf.string, name=None), TensorSpec(shape=(2,), dtype=tf.float64, \ 1780name=None)) 1781 >>> # `map_func` can return nested structures. 1782 >>> def i(x): 1783 ... return (37.0, [42, 16]), "foo" 1784 >>> result = dataset.map(i) 1785 >>> result.element_spec 1786 ((TensorSpec(shape=(), dtype=tf.float32, name=None), 1787 TensorSpec(shape=(2,), dtype=tf.int32, name=None)), 1788 TensorSpec(shape=(), dtype=tf.string, name=None)) 1789 1790 `map_func` can accept as arguments and return any type of dataset element. 1791 1792 Note that irrespective of the context in which `map_func` is defined (eager 1793 vs. graph), tf.data traces the function and executes it as a graph. To use 1794 Python code inside of the function you have a few options: 1795 1796 1) Rely on AutoGraph to convert Python code into an equivalent graph 1797 computation. The downside of this approach is that AutoGraph can convert 1798 some but not all Python code. 1799 1800 2) Use `tf.py_function`, which allows you to write arbitrary Python code but 1801 will generally result in worse performance than 1). For example: 1802 1803 >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world']) 1804 >>> # transform a string tensor to upper case string using a Python function 1805 >>> def upper_case_fn(t: tf.Tensor): 1806 ... return t.numpy().decode('utf-8').upper() 1807 >>> d = d.map(lambda x: tf.py_function(func=upper_case_fn, 1808 ... inp=[x], Tout=tf.string)) 1809 >>> list(d.as_numpy_iterator()) 1810 [b'HELLO', b'WORLD'] 1811 1812 3) Use `tf.numpy_function`, which also allows you to write arbitrary 1813 Python code. Note that `tf.py_function` accepts `tf.Tensor` whereas 1814 `tf.numpy_function` accepts numpy arrays and returns only numpy arrays. 1815 For example: 1816 1817 >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world']) 1818 >>> def upper_case_fn(t: np.ndarray): 1819 ... return t.decode('utf-8').upper() 1820 >>> d = d.map(lambda x: tf.numpy_function(func=upper_case_fn, 1821 ... inp=[x], Tout=tf.string)) 1822 >>> list(d.as_numpy_iterator()) 1823 [b'HELLO', b'WORLD'] 1824 1825 Note that the use of `tf.numpy_function` and `tf.py_function` 1826 in general precludes the possibility of executing user-defined 1827 transformations in parallel (because of Python GIL). 1828 1829 Performance can often be improved by setting `num_parallel_calls` so that 1830 `map` will use multiple threads to process elements. If deterministic order 1831 isn't required, it can also improve performance to set 1832 `deterministic=False`. 1833 1834 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] 1835 >>> dataset = dataset.map(lambda x: x + 1, 1836 ... num_parallel_calls=tf.data.AUTOTUNE, 1837 ... deterministic=False) 1838 1839 The order of elements yielded by this transformation is deterministic if 1840 `deterministic=True`. 
If `map_func` contains stateful operations and 1841 `num_parallel_calls > 1`, the order in which that state is accessed is 1842 undefined, so the values of output elements may not be deterministic 1843 regardless of the `deterministic` flag value. 1844 1845 Args: 1846 map_func: A function mapping a dataset element to another dataset element. 1847 num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`, 1848 representing the number elements to process asynchronously in parallel. 1849 If not specified, elements will be processed sequentially. If the value 1850 `tf.data.AUTOTUNE` is used, then the number of parallel 1851 calls is set dynamically based on available CPU. 1852 deterministic: (Optional.) When `num_parallel_calls` is specified, if this 1853 boolean is specified (`True` or `False`), it controls the order in which 1854 the transformation produces elements. If set to `False`, the 1855 transformation is allowed to yield elements out of order to trade 1856 determinism for performance. If not specified, the 1857 `tf.data.Options.deterministic` option (`True` by default) controls the 1858 behavior. 1859 1860 Returns: 1861 Dataset: A `Dataset`. 1862 """ 1863 if num_parallel_calls is None or DEBUG_MODE: 1864 if deterministic is not None and not DEBUG_MODE: 1865 warnings.warn("The `deterministic` argument has no effect unless the " 1866 "`num_parallel_calls` argument is specified.") 1867 return MapDataset(self, map_func, preserve_cardinality=True) 1868 else: 1869 return ParallelMapDataset( 1870 self, 1871 map_func, 1872 num_parallel_calls, 1873 deterministic, 1874 preserve_cardinality=True) 1875 1876 def flat_map(self, map_func): 1877 """Maps `map_func` across this dataset and flattens the result. 1878 1879 The type signature is: 1880 1881 ``` 1882 def flat_map( 1883 self: Dataset[T], 1884 map_func: Callable[[T], Dataset[S]] 1885 ) -> Dataset[S] 1886 ``` 1887 1888 Use `flat_map` if you want to make sure that the order of your dataset 1889 stays the same. For example, to flatten a dataset of batches into a 1890 dataset of their elements: 1891 1892 >>> dataset = tf.data.Dataset.from_tensor_slices( 1893 ... [[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 1894 >>> dataset = dataset.flat_map( 1895 ... lambda x: tf.data.Dataset.from_tensor_slices(x)) 1896 >>> list(dataset.as_numpy_iterator()) 1897 [1, 2, 3, 4, 5, 6, 7, 8, 9] 1898 1899 `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since 1900 `flat_map` produces the same output as 1901 `tf.data.Dataset.interleave(cycle_length=1)` 1902 1903 Args: 1904 map_func: A function mapping a dataset element to a dataset. 1905 1906 Returns: 1907 Dataset: A `Dataset`. 1908 """ 1909 return FlatMapDataset(self, map_func) 1910 1911 def interleave(self, 1912 map_func, 1913 cycle_length=None, 1914 block_length=None, 1915 num_parallel_calls=None, 1916 deterministic=None): 1917 """Maps `map_func` across this dataset, and interleaves the results. 1918 1919 The type signature is: 1920 1921 ``` 1922 def interleave( 1923 self: Dataset[T], 1924 map_func: Callable[[T], Dataset[S]] 1925 ) -> Dataset[S] 1926 ``` 1927 1928 For example, you can use `Dataset.interleave()` to process many input files 1929 concurrently: 1930 1931 >>> # Preprocess 4 files concurrently, and interleave blocks of 16 records 1932 >>> # from each file. 1933 >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt", 1934 ... "/var/data/file3.txt", "/var/data/file4.txt"] 1935 >>> dataset = tf.data.Dataset.from_tensor_slices(filenames) 1936 >>> def parse_fn(filename): 1937 ... 
return tf.data.Dataset.range(10) 1938 >>> dataset = dataset.interleave(lambda x: 1939 ... tf.data.TextLineDataset(x).map(parse_fn, num_parallel_calls=1), 1940 ... cycle_length=4, block_length=16) 1941 1942 The `cycle_length` and `block_length` arguments control the order in which 1943 elements are produced. `cycle_length` controls the number of input elements 1944 that are processed concurrently. If you set `cycle_length` to 1, this 1945 transformation will handle one input element at a time, and will produce 1946 identical results to `tf.data.Dataset.flat_map`. In general, 1947 this transformation will apply `map_func` to `cycle_length` input elements, 1948 open iterators on the returned `Dataset` objects, and cycle through them 1949 producing `block_length` consecutive elements from each iterator, and 1950 consuming the next input element each time it reaches the end of an 1951 iterator. 1952 1953 For example: 1954 1955 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] 1956 >>> # NOTE: New lines indicate "block" boundaries. 1957 >>> dataset = dataset.interleave( 1958 ... lambda x: Dataset.from_tensors(x).repeat(6), 1959 ... cycle_length=2, block_length=4) 1960 >>> list(dataset.as_numpy_iterator()) 1961 [1, 1, 1, 1, 1962 2, 2, 2, 2, 1963 1, 1, 1964 2, 2, 1965 3, 3, 3, 3, 1966 4, 4, 4, 4, 1967 3, 3, 1968 4, 4, 1969 5, 5, 5, 5, 1970 5, 5] 1971 1972 Note: The order of elements yielded by this transformation is 1973 deterministic, as long as `map_func` is a pure function and 1974 `deterministic=True`. If `map_func` contains any stateful operations, the 1975 order in which that state is accessed is undefined. 1976 1977 Performance can often be improved by setting `num_parallel_calls` so that 1978 `interleave` will use multiple threads to fetch elements. If determinism 1979 isn't required, it can also improve performance to set 1980 `deterministic=False`. 1981 1982 >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt", 1983 ... "/var/data/file3.txt", "/var/data/file4.txt"] 1984 >>> dataset = tf.data.Dataset.from_tensor_slices(filenames) 1985 >>> dataset = dataset.interleave(lambda x: tf.data.TFRecordDataset(x), 1986 ... cycle_length=4, num_parallel_calls=tf.data.AUTOTUNE, 1987 ... deterministic=False) 1988 1989 Args: 1990 map_func: A function that takes a dataset element and returns a 1991 `tf.data.Dataset`. 1992 cycle_length: (Optional.) The number of input elements that will be 1993 processed concurrently. If not set, the tf.data runtime decides what it 1994 should be based on available CPU. If `num_parallel_calls` is set to 1995 `tf.data.AUTOTUNE`, the `cycle_length` argument identifies 1996 the maximum degree of parallelism. 1997 block_length: (Optional.) The number of consecutive elements to produce 1998 from each input element before cycling to another input element. If not 1999 set, defaults to 1. 2000 num_parallel_calls: (Optional.) If specified, the implementation creates a 2001 threadpool, which is used to fetch inputs from cycle elements 2002 asynchronously and in parallel. The default behavior is to fetch inputs 2003 from cycle elements synchronously with no parallelism. If the value 2004 `tf.data.AUTOTUNE` is used, then the number of parallel 2005 calls is set dynamically based on available CPU. 2006 deterministic: (Optional.) When `num_parallel_calls` is specified, if this 2007 boolean is specified (`True` or `False`), it controls the order in which 2008 the transformation produces elements. 
If set to `False`, the 2009 transformation is allowed to yield elements out of order to trade 2010 determinism for performance. If not specified, the 2011 `tf.data.Options.deterministic` option (`True` by default) controls the 2012 behavior. 2013 2014 Returns: 2015 Dataset: A `Dataset`. 2016 """ 2017 if block_length is None: 2018 block_length = 1 2019 2020 if cycle_length is None: 2021 cycle_length = AUTOTUNE 2022 2023 if num_parallel_calls is None or DEBUG_MODE: 2024 if deterministic is not None and not DEBUG_MODE: 2025 warnings.warn("The `deterministic` argument has no effect unless the " 2026 "`num_parallel_calls` argument is specified.") 2027 return InterleaveDataset(self, map_func, cycle_length, block_length) 2028 else: 2029 return ParallelInterleaveDataset( 2030 self, 2031 map_func, 2032 cycle_length, 2033 block_length, 2034 num_parallel_calls, 2035 deterministic=deterministic) 2036 2037 def filter(self, predicate): 2038 """Filters this dataset according to `predicate`. 2039 2040 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 2041 >>> dataset = dataset.filter(lambda x: x < 3) 2042 >>> list(dataset.as_numpy_iterator()) 2043 [1, 2] 2044 >>> # `tf.math.equal(x, y)` is required for equality comparison 2045 >>> def filter_fn(x): 2046 ... return tf.math.equal(x, 1) 2047 >>> dataset = dataset.filter(filter_fn) 2048 >>> list(dataset.as_numpy_iterator()) 2049 [1] 2050 2051 Args: 2052 predicate: A function mapping a dataset element to a boolean. 2053 2054 Returns: 2055 Dataset: The `Dataset` containing the elements of this dataset for which 2056 `predicate` is `True`. 2057 """ 2058 return FilterDataset(self, predicate) 2059 2060 def apply(self, transformation_func): 2061 """Applies a transformation function to this dataset. 2062 2063 `apply` enables chaining of custom `Dataset` transformations, which are 2064 represented as functions that take one `Dataset` argument and return a 2065 transformed `Dataset`. 2066 2067 >>> dataset = tf.data.Dataset.range(100) 2068 >>> def dataset_fn(ds): 2069 ... return ds.filter(lambda x: x < 5) 2070 >>> dataset = dataset.apply(dataset_fn) 2071 >>> list(dataset.as_numpy_iterator()) 2072 [0, 1, 2, 3, 4] 2073 2074 Args: 2075 transformation_func: A function that takes one `Dataset` argument and 2076 returns a `Dataset`. 2077 2078 Returns: 2079 Dataset: The `Dataset` returned by applying `transformation_func` to this 2080 dataset. 2081 """ 2082 dataset = transformation_func(self) 2083 if not isinstance(dataset, DatasetV2): 2084 raise TypeError( 2085 "`transformation_func` must return a Dataset. Got {}.".format( 2086 dataset)) 2087 dataset._input_datasets = [self] # pylint: disable=protected-access 2088 return dataset 2089 2090 def window(self, size, shift=None, stride=1, drop_remainder=False): 2091 """Returns a dataset of "windows". 2092 2093 Each "window" is a dataset that contains a subset of elements of the 2094 input dataset. These are finite datasets of size `size` (or possibly fewer 2095 if there are not enough input elements to fill the window and 2096 `drop_remainder` evaluates to `False`). 2097 2098 For example: 2099 2100 >>> dataset = tf.data.Dataset.range(7).window(3) 2101 >>> for window in dataset: 2102 ... print(window) 2103 <...Dataset shapes: (), types: tf.int64> 2104 <...Dataset shapes: (), types: tf.int64> 2105 <...Dataset shapes: (), types: tf.int64> 2106 2107 Since windows are datasets, they can be iterated over: 2108 2109 >>> for window in dataset: 2110 ... 
print([item.numpy() for item in window]) 2111 [0, 1, 2] 2112 [3, 4, 5] 2113 [6] 2114 2115 #### Shift 2116 2117 The `shift` argument determines the number of input elements to shift 2118 between the start of each window. If windows and elements are both numbered 2119 starting at 0, the first element in window `k` will be element `k * shift` 2120 of the input dataset. In particular, the first element of the first window 2121 will always be the first element of the input dataset. 2122 2123 >>> dataset = tf.data.Dataset.range(7).window(3, shift=1, 2124 ... drop_remainder=True) 2125 >>> for window in dataset: 2126 ... print(list(window.as_numpy_iterator())) 2127 [0, 1, 2] 2128 [1, 2, 3] 2129 [2, 3, 4] 2130 [3, 4, 5] 2131 [4, 5, 6] 2132 2133 #### Stride 2134 2135 The `stride` argument determines the stride between input elements within a 2136 window. 2137 2138 >>> dataset = tf.data.Dataset.range(7).window(3, shift=1, stride=2, 2139 ... drop_remainder=True) 2140 >>> for window in dataset: 2141 ... print(list(window.as_numpy_iterator())) 2142 [0, 2, 4] 2143 [1, 3, 5] 2144 [2, 4, 6] 2145 2146 #### Nested elements 2147 2148 When the `window` transformation is applied to a dataset whos elements are 2149 nested structures, it produces a dataset where the elements have the same 2150 nested structure but each leaf is replaced by a window. In other words, 2151 the nesting is applied outside of the windows as opposed inside of them. 2152 2153 The type signature is: 2154 2155 ``` 2156 def window( 2157 self: Dataset[Nest[T]], ... 2158 ) -> Dataset[Nest[Dataset[T]]] 2159 ``` 2160 2161 Applying `window` to a `Dataset` of tuples gives a tuple of windows: 2162 2163 >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5], 2164 ... [6, 7, 8, 9, 10])) 2165 >>> dataset = dataset.window(2) 2166 >>> windows = next(iter(dataset)) 2167 >>> windows 2168 (<...Dataset shapes: (), types: tf.int32>, 2169 <...Dataset shapes: (), types: tf.int32>) 2170 2171 >>> def to_numpy(ds): 2172 ... return list(ds.as_numpy_iterator()) 2173 >>> 2174 >>> for windows in dataset: 2175 ... print(to_numpy(windows[0]), to_numpy(windows[1])) 2176 [1, 2] [6, 7] 2177 [3, 4] [8, 9] 2178 [5] [10] 2179 2180 Applying `window` to a `Dataset` of dictionaries gives a dictionary of 2181 `Datasets`: 2182 2183 >>> dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3], 2184 ... 'b': [4, 5, 6], 2185 ... 'c': [7, 8, 9]}) 2186 >>> dataset = dataset.window(2) 2187 >>> def to_numpy(ds): 2188 ... return list(ds.as_numpy_iterator()) 2189 >>> 2190 >>> for windows in dataset: 2191 ... print(tf.nest.map_structure(to_numpy, windows)) 2192 {'a': [1, 2], 'b': [4, 5], 'c': [7, 8]} 2193 {'a': [3], 'b': [6], 'c': [9]} 2194 2195 #### Flatten a dataset of windows 2196 2197 The `Dataset.flat_map` and `Dataset.interleave` methods can be used to 2198 flatten a dataset of windows into a single dataset. 2199 2200 The argument to `flat_map` is a function that takes an element from the 2201 dataset and returns a `Dataset`. `flat_map` chains together the resulting 2202 datasets sequentially. 2203 2204 For example, to turn each window into a dense tensor: 2205 2206 >>> size = 3 2207 >>> dataset = tf.data.Dataset.range(7).window(size, shift=1, 2208 ... drop_remainder=True) 2209 >>> batched = dataset.flat_map(lambda x:x.batch(3)) 2210 >>> for batch in batched: 2211 ... 
print(batch.numpy()) 2212 [0 1 2] 2213 [1 2 3] 2214 [2 3 4] 2215 [3 4 5] 2216 [4 5 6] 2217 2218 Args: 2219 size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements 2220 of the input dataset to combine into a window. Must be positive. 2221 shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 2222 number of input elements by which the window moves in each iteration. 2223 Defaults to `size`. Must be positive. 2224 stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 2225 stride of the input elements in the sliding window. Must be positive. 2226 The default value of 1 means "retain every input element". 2227 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 2228 whether the last windows should be dropped if their size is smaller than 2229 `size`. 2230 2231 Returns: 2232 Dataset: A `Dataset` of (nests of) windows. Each window is a finite 2233 datasets of flat elements. 2234 """ 2235 if shift is None: 2236 shift = size 2237 return WindowDataset(self, size, shift, stride, drop_remainder) 2238 2239 def reduce(self, initial_state, reduce_func): 2240 """Reduces the input dataset to a single element. 2241 2242 The transformation calls `reduce_func` successively on every element of 2243 the input dataset until the dataset is exhausted, aggregating information in 2244 its internal state. The `initial_state` argument is used for the initial 2245 state and the final state is returned as the result. 2246 2247 >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1).numpy() 2248 5 2249 >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y).numpy() 2250 10 2251 2252 Args: 2253 initial_state: An element representing the initial state of the 2254 transformation. 2255 reduce_func: A function that maps `(old_state, input_element)` to 2256 `new_state`. It must take two arguments and return a new element 2257 The structure of `new_state` must match the structure of 2258 `initial_state`. 2259 2260 Returns: 2261 A dataset element corresponding to the final state of the transformation. 2262 2263 """ 2264 2265 with ops.name_scope("initial_state"): 2266 initial_state = structure.normalize_element(initial_state) 2267 state_structure = structure.type_spec_from_value(initial_state) 2268 2269 # Iteratively rerun the reduce function until reaching a fixed point on 2270 # `state_structure`. 2271 need_to_rerun = True 2272 while need_to_rerun: 2273 2274 wrapped_func = StructuredFunctionWrapper( 2275 reduce_func, 2276 "reduce()", 2277 input_structure=(state_structure, self.element_spec), 2278 add_to_graph=False) 2279 2280 # Extract and validate class information from the returned values. 2281 output_classes = wrapped_func.output_classes 2282 state_classes = nest.map_structure( 2283 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 2284 state_structure) 2285 for new_state_class, state_class in zip( 2286 nest.flatten(output_classes), nest.flatten(state_classes)): 2287 if not issubclass(new_state_class, state_class): 2288 raise TypeError( 2289 "The element classes for the new state must match the initial " 2290 "state. Expected %s; got %s." % 2291 (state_classes, wrapped_func.output_classes)) 2292 2293 # Extract and validate type information from the returned values. 
2294 output_types = wrapped_func.output_types 2295 state_types = nest.map_structure( 2296 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 2297 state_structure) 2298 for new_state_type, state_type in zip( 2299 nest.flatten(output_types), nest.flatten(state_types)): 2300 if new_state_type != state_type: 2301 raise TypeError( 2302 "The element types for the new state must match the initial " 2303 "state. Expected %s; got %s." % 2304 (state_types, wrapped_func.output_types)) 2305 2306 # Extract shape information from the returned values. 2307 output_shapes = wrapped_func.output_shapes 2308 state_shapes = nest.map_structure( 2309 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 2310 state_structure) 2311 flat_state_shapes = nest.flatten(state_shapes) 2312 flat_new_state_shapes = nest.flatten(output_shapes) 2313 weakened_state_shapes = [ 2314 original.most_specific_compatible_shape(new) 2315 for original, new in zip(flat_state_shapes, flat_new_state_shapes) 2316 ] 2317 2318 need_to_rerun = False 2319 for original_shape, weakened_shape in zip(flat_state_shapes, 2320 weakened_state_shapes): 2321 if original_shape.ndims is not None and ( 2322 weakened_shape.ndims is None or 2323 original_shape.as_list() != weakened_shape.as_list()): 2324 need_to_rerun = True 2325 break 2326 2327 if need_to_rerun: 2328 # TODO(b/110122868): Support a "most specific compatible structure" 2329 # method for combining structures, to avoid using legacy structures 2330 # here. 2331 state_structure = structure.convert_legacy_structure( 2332 state_types, 2333 nest.pack_sequence_as(state_shapes, weakened_state_shapes), 2334 state_classes) 2335 2336 reduce_func = wrapped_func.function 2337 reduce_func.add_to_graph(ops.get_default_graph()) 2338 2339 dataset = self._apply_debug_options() 2340 2341 # pylint: disable=protected-access 2342 return structure.from_compatible_tensor_list( 2343 state_structure, 2344 gen_dataset_ops.reduce_dataset( 2345 dataset._variant_tensor, 2346 structure.to_tensor_list(state_structure, initial_state), 2347 reduce_func.captured_inputs, 2348 f=reduce_func, 2349 output_shapes=structure.get_flat_tensor_shapes(state_structure), 2350 output_types=structure.get_flat_tensor_types(state_structure))) 2351 2352 def get_single_element(self): 2353 """Returns the single element of the `dataset` as a nested structure of tensors. 2354 2355 The function enables you to use a `tf.data.Dataset` in a stateless 2356 "tensor-in tensor-out" expression, without creating an iterator. 2357 This facilitates the ease of data transformation on tensors using the 2358 optimized `tf.data.Dataset` abstraction on top of them. 2359 2360 For example, lets consider a `preprocessing_fn` which would take as an 2361 input the raw features and returns the processed feature along with 2362 it's label. 2363 2364 ```python 2365 def preprocessing_fn(raw_feature): 2366 # ... the raw_feature is preprocessed as per the use-case 2367 return feature 2368 2369 raw_features = ... # input batch of BATCH_SIZE elements. 2370 dataset = (tf.data.Dataset.from_tensor_slices(raw_features) 2371 .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) 2372 .batch(BATCH_SIZE)) 2373 2374 processed_features = dataset.get_single_element() 2375 ``` 2376 2377 In the above example, the `raw_features` tensor of length=BATCH_SIZE 2378 was converted to a `tf.data.Dataset`. 
Next, each `raw_feature` was mapped using the `preprocessing_fn` and the
    processed features were grouped into a single batch. The final `dataset`
    contains only one element, which is a batch of all the processed features.

    NOTE: The `dataset` should contain only one element.

    Now, instead of creating an iterator for the `dataset` and retrieving the
    batch of features, the `dataset.get_single_element()` method is used to
    skip the iterator creation and directly output the batch of features.

    This can be particularly useful when your tensor transformations are
    expressed as `tf.data.Dataset` operations, and you want to use those
    transformations while serving your model.

    #### Keras

    ```python

    model = ... # A pre-built or custom model

    class PreprocessingModel(tf.keras.Model):
      def __init__(self, model):
        super().__init__()
        self.model = model

      @tf.function(input_signature=[...])
      def serving_fn(self, data):
        ds = tf.data.Dataset.from_tensor_slices(data)
        ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
        ds = ds.batch(batch_size=BATCH_SIZE)
        return tf.argmax(self.model(ds.get_single_element()), axis=-1)

    preprocessing_model = PreprocessingModel(model)
    your_exported_model_dir = ... # save the model to this path.
    tf.saved_model.save(preprocessing_model, your_exported_model_dir,
                  signatures={'serving_default': preprocessing_model.serving_fn}
                  )
    ```

    #### Estimator

    In the case of estimators, you generally need to define a
    `serving_input_fn` that processes the raw features before they are passed
    to the model for inference.

    ```python
    def serving_input_fn():

      raw_feature_spec = ... # Spec for the raw_features
      input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
          raw_feature_spec, default_batch_size=None)
      serving_input_receiver = input_fn()
      raw_features = serving_input_receiver.features

      def preprocessing_fn(raw_feature):
        # ... the raw_feature is preprocessed as per the use-case
        return feature

      dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
                .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
                .batch(BATCH_SIZE))

      processed_features = dataset.get_single_element()

      # Please note that the value of `BATCH_SIZE` should be equal to
      # the size of the leading dimension of `raw_features`. This ensures
      # that `dataset` has only one element, which is a prerequisite for
      # using `dataset.get_single_element()`.

      return tf.estimator.export.ServingInputReceiver(
          processed_features, serving_input_receiver.receiver_tensors)

    estimator = ... # A pre-built or custom estimator
    estimator.export_saved_model(your_exported_model_dir, serving_input_fn)
    ```

    Returns:
      A nested structure of `tf.Tensor` objects, corresponding to the single
      element of `dataset`.

    Raises:
      InvalidArgumentError: (at runtime) if `dataset` does not contain exactly
        one element.
    """

    return structure.from_compatible_tensor_list(
        self.element_spec,
        gen_dataset_ops.dataset_to_single_element(self._variant_tensor,
                                                  **self._flat_structure))  # pylint: disable=protected-access

  def unbatch(self):
    """Splits elements of a dataset into multiple elements.
2473 2474 For example, if elements of the dataset are shaped `[B, a0, a1, ...]`, 2475 where `B` may vary for each input element, then for each element in the 2476 dataset, the unbatched dataset will contain `B` consecutive elements 2477 of shape `[a0, a1, ...]`. 2478 2479 >>> elements = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ] 2480 >>> dataset = tf.data.Dataset.from_generator(lambda: elements, tf.int64) 2481 >>> dataset = dataset.unbatch() 2482 >>> list(dataset.as_numpy_iterator()) 2483 [1, 2, 3, 1, 2, 1, 2, 3, 4] 2484 2485 Note: `unbatch` requires a data copy to slice up the batched tensor into 2486 smaller, unbatched tensors. When optimizing performance, try to avoid 2487 unnecessary usage of `unbatch`. 2488 2489 Returns: 2490 A `Dataset`. 2491 """ 2492 normalized_dataset = normalize_to_dense(self) 2493 return _UnbatchDataset(normalized_dataset) 2494 2495 def with_options(self, options): 2496 """Returns a new `tf.data.Dataset` with the given options set. 2497 2498 The options are "global" in the sense they apply to the entire dataset. 2499 If options are set multiple times, they are merged as long as different 2500 options do not use different non-default values. 2501 2502 >>> ds = tf.data.Dataset.range(5) 2503 >>> ds = ds.interleave(lambda x: tf.data.Dataset.range(5), 2504 ... cycle_length=3, 2505 ... num_parallel_calls=3) 2506 >>> options = tf.data.Options() 2507 >>> # This will make the interleave order non-deterministic. 2508 >>> options.deterministic = False 2509 >>> ds = ds.with_options(options) 2510 2511 Args: 2512 options: A `tf.data.Options` that identifies the options the use. 2513 2514 Returns: 2515 Dataset: A `Dataset` with the given options. 2516 2517 Raises: 2518 ValueError: when an option is set more than once to a non-default value 2519 """ 2520 return _OptionsDataset(self, options) 2521 2522 def cardinality(self): 2523 """Returns the cardinality of the dataset, if known. 2524 2525 `cardinality` may return `tf.data.INFINITE_CARDINALITY` if the dataset 2526 contains an infinite number of elements or `tf.data.UNKNOWN_CARDINALITY` if 2527 the analysis fails to determine the number of elements in the dataset 2528 (e.g. when the dataset source is a file). 2529 2530 >>> dataset = tf.data.Dataset.range(42) 2531 >>> print(dataset.cardinality().numpy()) 2532 42 2533 >>> dataset = dataset.repeat() 2534 >>> cardinality = dataset.cardinality() 2535 >>> print((cardinality == tf.data.INFINITE_CARDINALITY).numpy()) 2536 True 2537 >>> dataset = dataset.filter(lambda x: True) 2538 >>> cardinality = dataset.cardinality() 2539 >>> print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy()) 2540 True 2541 2542 Returns: 2543 A scalar `tf.int64` `Tensor` representing the cardinality of the dataset. 2544 If the cardinality is infinite or unknown, `cardinality` returns the 2545 named constants `tf.data.INFINITE_CARDINALITY` and 2546 `tf.data.UNKNOWN_CARDINALITY` respectively. 2547 """ 2548 return gen_dataset_ops.dataset_cardinality(self._variant_tensor) 2549 2550 def group_by_window(self, 2551 key_func, 2552 reduce_func, 2553 window_size=None, 2554 window_size_func=None): 2555 """Groups windows of elements by key and reduces them. 2556 2557 This transformation maps each consecutive element in a dataset to a key 2558 using `key_func` and groups the elements by key. It then applies 2559 `reduce_func` to at most `window_size_func(key)` elements matching the same 2560 key. All except the final window for each key will contain 2561 `window_size_func(key)` elements; the final window may be smaller. 
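    For example, the window size can also be chosen per key through
    `window_size_func` (an illustrative sketch):

    ```python
    dataset = tf.data.Dataset.range(10)
    dataset = dataset.group_by_window(
        key_func=lambda x: x % 2,
        reduce_func=lambda key, window: window.batch(10),
        window_size_func=lambda key: (key + 1) * 2)
    # Even elements are grouped into windows of 2, odd elements into
    # windows of 4; each window is then emitted as a single batch.
    ```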
2562 2563 You may provide either a constant `window_size` or a window size determined 2564 by the key through `window_size_func`. 2565 2566 >>> dataset = tf.data.Dataset.range(10) 2567 >>> window_size = 5 2568 >>> key_func = lambda x: x%2 2569 >>> reduce_func = lambda key, dataset: dataset.batch(window_size) 2570 >>> dataset = dataset.group_by_window( 2571 ... key_func=key_func, 2572 ... reduce_func=reduce_func, 2573 ... window_size=window_size) 2574 >>> for elem in dataset.as_numpy_iterator(): 2575 ... print(elem) 2576 [0 2 4 6 8] 2577 [1 3 5 7 9] 2578 2579 Args: 2580 key_func: A function mapping a nested structure of tensors (having shapes 2581 and types defined by `self.output_shapes` and `self.output_types`) to a 2582 scalar `tf.int64` tensor. 2583 reduce_func: A function mapping a key and a dataset of up to `window_size` 2584 consecutive elements matching that key to another dataset. 2585 window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 2586 consecutive elements matching the same key to combine in a single batch, 2587 which will be passed to `reduce_func`. Mutually exclusive with 2588 `window_size_func`. 2589 window_size_func: A function mapping a key to a `tf.int64` scalar 2590 `tf.Tensor`, representing the number of consecutive elements matching 2591 the same key to combine in a single batch, which will be passed to 2592 `reduce_func`. Mutually exclusive with `window_size`. 2593 2594 Returns: 2595 A `Dataset`. 2596 2597 Raises: 2598 ValueError: if neither or both of {`window_size`, `window_size_func`} are 2599 passed. 2600 """ 2601 if (window_size is not None and window_size_func or 2602 not (window_size is not None or window_size_func)): 2603 raise ValueError("Must pass either window_size or window_size_func.") 2604 2605 if window_size is not None: 2606 2607 def constant_window_func(unused_key): 2608 return ops.convert_to_tensor(window_size, dtype=dtypes.int64) 2609 2610 window_size_func = constant_window_func 2611 2612 assert window_size_func is not None 2613 2614 return _GroupByWindowDataset(self, key_func, reduce_func, window_size_func) 2615 2616 def bucket_by_sequence_length(self, 2617 element_length_func, 2618 bucket_boundaries, 2619 bucket_batch_sizes, 2620 padded_shapes=None, 2621 padding_values=None, 2622 pad_to_bucket_boundary=False, 2623 no_padding=False, 2624 drop_remainder=False): 2625 """A transformation that buckets elements in a `Dataset` by length. 2626 2627 Elements of the `Dataset` are grouped together by length and then are padded 2628 and batched. 2629 2630 This is useful for sequence tasks in which the elements have variable 2631 length. Grouping together elements that have similar lengths reduces the 2632 total fraction of padding in a batch which increases training step 2633 efficiency. 2634 2635 Below is an example to bucketize the input data to the 3 buckets 2636 "[0, 3), [3, 5), [5, inf)" based on sequence length, with batch size 2. 2637 2638 >>> elements = [ 2639 ... [0], [1, 2, 3, 4], [5, 6, 7], 2640 ... [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]] 2641 >>> dataset = tf.data.Dataset.from_generator( 2642 ... lambda: elements, tf.int64, output_shapes=[None]) 2643 >>> dataset = dataset.bucket_by_sequence_length( 2644 ... element_length_func=lambda elem: tf.shape(elem)[0], 2645 ... bucket_boundaries=[3, 5], 2646 ... bucket_batch_sizes=[2, 2, 2]) 2647 >>> for elem in dataset.as_numpy_iterator(): 2648 ... 
print(elem) 2649 [[1 2 3 4] 2650 [5 6 7 0]] 2651 [[ 7 8 9 10 11 0] 2652 [13 14 15 16 19 20]] 2653 [[ 0 0] 2654 [21 22]] 2655 2656 Args: 2657 element_length_func: function from element in `Dataset` to `tf.int32`, 2658 determines the length of the element, which will determine the bucket it 2659 goes into. 2660 bucket_boundaries: `list<int>`, upper length boundaries of the buckets. 2661 bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be 2662 `len(bucket_boundaries) + 1`. 2663 padded_shapes: Nested structure of `tf.TensorShape` to pass to 2664 `tf.data.Dataset.padded_batch`. If not provided, will use 2665 `dataset.output_shapes`, which will result in variable length dimensions 2666 being padded out to the maximum length in each batch. 2667 padding_values: Values to pad with, passed to 2668 `tf.data.Dataset.padded_batch`. Defaults to padding with 0. 2669 pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown 2670 size to maximum length in batch. If `True`, will pad dimensions with 2671 unknown size to bucket boundary minus 1 (i.e., the maximum length in 2672 each bucket), and caller must ensure that the source `Dataset` does not 2673 contain any elements with length longer than `max(bucket_boundaries)`. 2674 no_padding: `bool`, indicates whether to pad the batch features (features 2675 need to be either of type `tf.sparse.SparseTensor` or of same shape). 2676 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 2677 whether the last batch should be dropped in the case it has fewer than 2678 `batch_size` elements; the default behavior is not to drop the smaller 2679 batch. 2680 2681 Returns: 2682 A `Dataset`. 2683 2684 Raises: 2685 ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`. 
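    When `pad_to_bucket_boundary=True`, each dimension of unknown size is
    padded to its bucket's upper boundary minus 1 rather than to the longest
    element in the batch, and every element must be shorter than
    `max(bucket_boundaries)`. A minimal sketch:

    ```python
    dataset = tf.data.Dataset.from_generator(
        lambda: iter([[0], [1, 2], [3, 4, 5], [6, 7, 8, 9]]),
        tf.int64, output_shapes=[None])
    dataset = dataset.bucket_by_sequence_length(
        element_length_func=lambda elem: tf.shape(elem)[0],
        bucket_boundaries=[3, 5],
        bucket_batch_sizes=[2, 2, 2],
        pad_to_bucket_boundary=True)
    # Sequences shorter than 3 are padded to length 2; sequences of length
    # 3 or 4 are padded to length 4.
    ```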
    """
    if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
      raise ValueError(
          "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1")

    batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)

    def element_to_bucket_id(*args):
      """Return int64 id of the length bucket for this element."""
      seq_length = element_length_func(*args)

      boundaries = list(bucket_boundaries)
      buckets_min = [np.iinfo(np.int32).min] + boundaries
      buckets_max = boundaries + [np.iinfo(np.int32).max]
      conditions_c = math_ops.logical_and(
          math_ops.less_equal(buckets_min, seq_length),
          math_ops.less(seq_length, buckets_max))
      bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))

      return bucket_id

    def window_size_fn(bucket_id):
      # The window size is set to the batch size for this bucket.
      window_size = batch_sizes[bucket_id]
      return window_size

    def make_padded_shapes(shapes, none_filler=None):
      padded = []
      for shape in nest.flatten(shapes):
        shape = tensor_shape.TensorShape(shape)
        shape = [
            none_filler if tensor_shape.dimension_value(d) is None else d
            for d in shape
        ]
        padded.append(shape)
      return nest.pack_sequence_as(shapes, padded)

    def batching_fn(bucket_id, grouped_dataset):
      """Batch elements in dataset."""
      batch_size = window_size_fn(bucket_id)
      if no_padding:
        return grouped_dataset.batch(batch_size, drop_remainder=drop_remainder)
      none_filler = None
      if pad_to_bucket_boundary:
        err_msg = ("When pad_to_bucket_boundary=True, elements must have "
                   "length < max(bucket_boundaries).")
        check = check_ops.assert_less(
            bucket_id,
            constant_op.constant(
                len(bucket_batch_sizes) - 1, dtype=dtypes.int64),
            message=err_msg)
        with ops.control_dependencies([check]):
          boundaries = constant_op.constant(
              bucket_boundaries, dtype=dtypes.int64)
          bucket_boundary = boundaries[bucket_id]
          none_filler = bucket_boundary - 1
      input_shapes = get_legacy_output_shapes(grouped_dataset)
      shapes = make_padded_shapes(
          padded_shapes or input_shapes, none_filler=none_filler)
      return grouped_dataset.padded_batch(
          batch_size, shapes, padding_values, drop_remainder=drop_remainder)

    return self.group_by_window(
        key_func=element_to_bucket_id,
        reduce_func=batching_fn,
        window_size_func=window_size_fn)

  @staticmethod
  def random(seed=None):
    """Creates a `Dataset` of pseudorandom values.

    The dataset generates a sequence of uniformly distributed integer values.

    >>> ds1 = tf.data.Dataset.random(seed=4).take(10)
    >>> ds2 = tf.data.Dataset.random(seed=4).take(10)
    >>> print(list(ds1.as_numpy_iterator())==list(ds2.as_numpy_iterator()))
    True

    Args:
      seed: (Optional.) If specified, the dataset produces a deterministic
        sequence of values.

    Returns:
      Dataset: A `Dataset`.
    """
    return RandomDataset(seed=seed)

  def snapshot(self,
               path,
               compression="AUTO",
               reader_func=None,
               shard_func=None):
    """API to persist the output of the input dataset.

    The snapshot API allows users to transparently persist the output of their
    preprocessing pipeline to disk, and materialize the pre-processed data on a
    different training run.
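    For example, a minimal use might look like this (an illustrative sketch;
    `preprocess` and the paths are placeholders):

    ```python
    dataset = tf.data.Dataset.list_files("/path/to/data/*.tfrecord")
    dataset = dataset.interleave(tf.data.TFRecordDataset)
    dataset = dataset.map(preprocess)  # expensive preprocessing step
    dataset = dataset.snapshot("/path/to/snapshot/dir")
    dataset = dataset.shuffle(10000)
    # The first run writes the snapshot; later runs read the preprocessed
    # elements back from the snapshot directory instead of recomputing them.
    ```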
2783 2784 This API enables repeated preprocessing steps to be consolidated, and allows 2785 re-use of already processed data, trading off disk storage and network 2786 bandwidth for freeing up more valuable CPU resources and accelerator compute 2787 time. 2788 2789 https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md 2790 has detailed design documentation of this feature. 2791 2792 Users can specify various options to control the behavior of snapshot, 2793 including how snapshots are read from and written to by passing in 2794 user-defined functions to the `reader_func` and `shard_func` parameters. 2795 2796 `shard_func` is a user specified function that maps input elements to 2797 snapshot shards. 2798 2799 Users may want to specify this function to control how snapshot files should 2800 be written to disk. Below is an example of how a potential `shard_func` 2801 could be written. 2802 2803 ```python 2804 dataset = ... 2805 dataset = dataset.enumerate() 2806 dataset = dataset.snapshot("/path/to/snapshot/dir", 2807 shard_func=lambda x, y: x % NUM_SHARDS, ...) 2808 dataset = dataset.map(lambda x, y: y) 2809 ``` 2810 2811 `reader_func` is a user specified function that accepts a single argument: 2812 (1) a Dataset of Datasets, each representing a "split" of elements of the 2813 original dataset. The cardinality of the input dataset matches the 2814 number of the shards specified in the `shard_func` (see above). The function 2815 should return a Dataset of elements of the original dataset. 2816 2817 Users may want specify this function to control how snapshot files should be 2818 read from disk, including the amount of shuffling and parallelism. 2819 2820 Here is an example of a standard reader function a user can define. This 2821 function enables both dataset shuffling and parallel reading of datasets: 2822 2823 ```python 2824 def user_reader_func(datasets): 2825 # shuffle the datasets splits 2826 datasets = datasets.shuffle(NUM_CORES) 2827 # read datasets in parallel and interleave their elements 2828 return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE) 2829 2830 dataset = dataset.snapshot("/path/to/snapshot/dir", 2831 reader_func=user_reader_func) 2832 ``` 2833 2834 By default, snapshot parallelizes reads by the number of cores available on 2835 the system, but will not attempt to shuffle the data. 2836 2837 Args: 2838 path: Required. A directory to use for storing / loading the snapshot to / 2839 from. 2840 compression: Optional. The type of compression to apply to the snapshot 2841 written to disk. Supported options are `GZIP`, `SNAPPY`, `AUTO` or None. 2842 Defaults to `AUTO`, which attempts to pick an appropriate compression 2843 algorithm for the dataset. 2844 reader_func: Optional. A function to control how to read data from 2845 snapshot shards. 2846 shard_func: Optional. A function to control how to shard data when writing 2847 a snapshot. 2848 2849 Returns: 2850 A `Dataset`. 2851 """ 2852 2853 project_func = None 2854 input_dataset = self 2855 if shard_func is None: 2856 input_dataset = input_dataset.enumerate() 2857 # This sets the amount of parallelism based on the number of CPU cores on 2858 # the machine where this Python code is executed, which may differ from 2859 # the number of CPU cores where the input pipeline graph is actually 2860 # executed (e.g. remote Cloud TPU workers). 
2861 local_shard_func = lambda index, _: index % multiprocessing.cpu_count() 2862 project_func = lambda _, elem: elem 2863 else: 2864 local_shard_func = shard_func 2865 dataset = _SnapshotDataset( 2866 input_dataset=input_dataset, 2867 path=path, 2868 compression=compression, 2869 reader_func=reader_func, 2870 # This will not do the right thing where the graph is built on a 2871 # different machine than the executor (e.g. Cloud TPUs). 2872 shard_func=local_shard_func) 2873 if project_func is not None: 2874 dataset = dataset.map(project_func) 2875 return dataset 2876 2877 def scan(self, initial_state, scan_func): 2878 """A transformation that scans a function across an input dataset. 2879 2880 This transformation is a stateful relative of `tf.data.Dataset.map`. 2881 In addition to mapping `scan_func` across the elements of the input dataset, 2882 `scan()` accumulates one or more state tensors, whose initial values are 2883 `initial_state`. 2884 2885 >>> dataset = tf.data.Dataset.range(10) 2886 >>> initial_state = tf.constant(0, dtype=tf.int64) 2887 >>> scan_func = lambda state, i: (state + i, state + i) 2888 >>> dataset = dataset.scan(initial_state=initial_state, scan_func=scan_func) 2889 >>> list(dataset.as_numpy_iterator()) 2890 [0, 1, 3, 6, 10, 15, 21, 28, 36, 45] 2891 2892 Args: 2893 initial_state: A nested structure of tensors, representing the initial 2894 state of the accumulator. 2895 scan_func: A function that maps `(old_state, input_element)` to 2896 `(new_state, output_element)`. It must take two arguments and return a 2897 pair of nested structures of tensors. The `new_state` must match the 2898 structure of `initial_state`. 2899 2900 Returns: 2901 A `Dataset`. 2902 """ 2903 2904 return _ScanDataset(self, initial_state=initial_state, scan_func=scan_func) 2905 2906 def take_while(self, predicate): 2907 """A transformation that stops dataset iteration based on a `predicate`. 2908 2909 >>> dataset = tf.data.Dataset.range(10) 2910 >>> dataset = dataset.take_while(lambda x: x < 5) 2911 >>> list(dataset.as_numpy_iterator()) 2912 [0, 1, 2, 3, 4] 2913 2914 Args: 2915 predicate: A function that maps a nested structure of tensors (having 2916 shapes and types defined by `self.output_shapes` and 2917 `self.output_types`) to a scalar `tf.bool` tensor. 2918 2919 Returns: 2920 A `Dataset`. 2921 """ 2922 2923 return _TakeWhileDataset(self, predicate) 2924 2925 def unique(self): 2926 """A transformation that discards duplicate elements of a `Dataset`. 2927 2928 Use this transformation to produce a dataset that contains one instance of 2929 each unique element in the input. For example: 2930 2931 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1]) 2932 >>> dataset = dataset.unique() 2933 >>> sorted(list(dataset.as_numpy_iterator())) 2934 [1, 2, 37] 2935 2936 Note: This transformation only supports datasets which fit into memory 2937 and have elements of either `tf.int32`, `tf.int64` or `tf.string` type. 2938 2939 Returns: 2940 A `Dataset`. 2941 """ 2942 2943 return _UniqueDataset(self) 2944 2945 def rejection_resample(self, 2946 class_func, 2947 target_dist, 2948 initial_dist=None, 2949 seed=None): 2950 """A transformation that resamples a dataset to achieve a target distribution. 2951 2952 Lets consider the following example where a dataset with an initial data 2953 distribution of `init_dist` needs to be resampled into a dataset with 2954 `target_dist` distribution. 
2955 2956 >>> import collections 2957 >>> initial_dist = [0.5, 0.5] 2958 >>> target_dist = [0.6, 0.4] 2959 >>> num_classes = len(initial_dist) 2960 >>> num_samples = 100000 2961 >>> data_np = np.random.choice(num_classes, num_samples, p=initial_dist) 2962 >>> dataset = tf.data.Dataset.from_tensor_slices(data_np) 2963 >>> x = collections.defaultdict(int) 2964 >>> for i in dataset: 2965 ... x[i.numpy()] += 1 2966 2967 The value of `x` will be close to `{0: 50000, 1: 50000}` as per the 2968 `initial_dist` distribution. 2969 2970 >>> dataset = dataset.rejection_resample( 2971 ... class_func=lambda x: x % 2, 2972 ... target_dist=target_dist, 2973 ... initial_dist=initial_dist) 2974 2975 >>> y = collections.defaultdict(int) 2976 >>> for i in dataset: 2977 ... cls, _ = i 2978 ... y[cls.numpy()] += 1 2979 2980 The value of `y` will be now be close to `{0: 75000, 1: 50000}` thus 2981 satisfying the `target_dist` distribution. 2982 2983 Args: 2984 class_func: A function mapping an element of the input dataset to a scalar 2985 `tf.int32` tensor. Values should be in `[0, num_classes)`. 2986 target_dist: A floating point type tensor, shaped `[num_classes]`. 2987 initial_dist: (Optional.) A floating point type tensor, shaped 2988 `[num_classes]`. If not provided, the true class distribution is 2989 estimated live in a streaming fashion. 2990 seed: (Optional.) Python integer seed for the resampler. 2991 2992 Returns: 2993 A `Dataset` 2994 """ 2995 2996 target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist") 2997 target_dist_t = math_ops.cast(target_dist_t, dtypes.float32) 2998 2999 # Get initial distribution. 3000 if initial_dist is not None: 3001 initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") 3002 initial_dist_t = math_ops.cast(initial_dist_t, dtypes.float32) 3003 acceptance_dist, prob_of_original = ( 3004 _calculate_acceptance_probs_with_mixing(initial_dist_t, 3005 target_dist_t)) 3006 initial_dist_ds = DatasetV2.from_tensors(initial_dist_t).repeat() 3007 acceptance_dist_ds = DatasetV2.from_tensors(acceptance_dist).repeat() 3008 prob_of_original_ds = DatasetV2.from_tensors(prob_of_original).repeat() 3009 else: 3010 initial_dist_ds = _estimate_initial_dist_ds(target_dist_t, 3011 self.map(class_func)) 3012 acceptance_and_original_prob_ds = initial_dist_ds.map( 3013 lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda 3014 initial, target_dist_t)) 3015 acceptance_dist_ds = acceptance_and_original_prob_ds.map( 3016 lambda accept_prob, _: accept_prob) 3017 prob_of_original_ds = acceptance_and_original_prob_ds.map( 3018 lambda _, prob_original: prob_original) 3019 filtered_ds = _filter_ds(self, acceptance_dist_ds, initial_dist_ds, 3020 class_func, seed) 3021 # Prefetch filtered dataset for speed. 
3022 filtered_ds = filtered_ds.prefetch(3) 3023 3024 prob_original_static = _get_prob_original_static( 3025 initial_dist_t, target_dist_t) if initial_dist is not None else None 3026 3027 def add_class_value(*x): 3028 if len(x) == 1: 3029 return class_func(*x), x[0] 3030 else: 3031 return class_func(*x), x 3032 3033 if prob_original_static == 1: 3034 return self.map(add_class_value) 3035 elif prob_original_static == 0: 3036 return filtered_ds 3037 else: 3038 return interleave_ops.sample_from_datasets( 3039 [self.map(add_class_value), filtered_ds], 3040 weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]), 3041 seed=seed, 3042 stop_on_empty_dataset=True) 3043 3044 3045@tf_export(v1=["data.Dataset"]) 3046class DatasetV1(DatasetV2): 3047 """Represents a potentially large set of elements. 3048 3049 A `Dataset` can be used to represent an input pipeline as a 3050 collection of elements and a "logical plan" of transformations that act on 3051 those elements. 3052 """ 3053 3054 def __init__(self): 3055 try: 3056 variant_tensor = self._as_variant_tensor() 3057 except AttributeError as e: 3058 if "_as_variant_tensor" in str(e): 3059 raise AttributeError("Please use _variant_tensor instead of " 3060 "_as_variant_tensor() to obtain the variant " 3061 "associated with a dataset") 3062 raise AttributeError("{}: A likely cause of this error is that the super " 3063 "call for this dataset is not the last line of the " 3064 "__init__ method. The base class causes the " 3065 "_as_variant_tensor call in its constructor and " 3066 "if that uses attributes defined in the __init__ " 3067 "method, those attrs need to be defined before the " 3068 "super call.".format(e)) 3069 super(DatasetV1, self).__init__(variant_tensor) 3070 3071 @abc.abstractmethod 3072 def _as_variant_tensor(self): 3073 """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset. 3074 3075 Returns: 3076 A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset. 3077 """ 3078 raise NotImplementedError("Dataset._as_variant_tensor") 3079 3080 @deprecation.deprecated( 3081 None, "This is a deprecated API that should only be used in TF 1 graph " 3082 "mode and legacy TF 2 graph mode available through `tf.compat.v1`. In " 3083 "all other situations -- namely, eager mode and inside `tf.function` -- " 3084 "you can consume dataset elements using `for elem in dataset: ...` or " 3085 "by explicitly creating iterator via `iterator = iter(dataset)` and " 3086 "fetching its elements via `values = next(iterator)`. Furthermore, " 3087 "this API is not available in TF 2. During the transition from TF 1 " 3088 "to TF 2 you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)` " 3089 "to create a TF 1 graph mode style iterator for a dataset created " 3090 "through TF 2 APIs. Note that this should be a transient state of your " 3091 "code base as there are in general no guarantees about the " 3092 "interoperability of TF 1 and TF 2 code.") 3093 def make_one_shot_iterator(self): 3094 """Creates an iterator for elements of this dataset. 3095 3096 Note: The returned iterator will be initialized automatically. 3097 A "one-shot" iterator does not currently support re-initialization. For 3098 that see `make_initializable_iterator`. 3099 3100 Example: 3101 3102 ```python 3103 # Building graph ... 3104 dataset = ... 3105 next_value = dataset.make_one_shot_iterator().get_next() 3106 3107 # ... from within a session ... 3108 try: 3109 while True: 3110 value = sess.run(next_value) 3111 ... 
3112 except tf.errors.OutOfRangeError: 3113 pass 3114 ``` 3115 3116 Returns: 3117 An `tf.data.Iterator` for elements of this dataset. 3118 """ 3119 return self._make_one_shot_iterator() 3120 3121 def _make_one_shot_iterator(self): # pylint: disable=missing-docstring 3122 if context.executing_eagerly(): 3123 with ops.colocate_with(self._variant_tensor): 3124 return iterator_ops.OwnedIterator(self) 3125 3126 _ensure_same_dataset_graph(self) 3127 # Some ops (e.g. dataset ops) are marked as stateful but are stil safe to 3128 # to capture by value. We must allowlist these ops so that the capturing 3129 # logic captures the ops instead of raising an exception. 3130 allowlisted_stateful_ops = traverse.obtain_capture_by_value_ops(self) 3131 graph_level_seed, op_level_seed = core_random_seed.get_seed(None) 3132 3133 # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is 3134 # a 0-argument function. 3135 @function.Defun( 3136 capture_by_value=True, 3137 allowlisted_stateful_ops=allowlisted_stateful_ops) 3138 def _make_dataset(): 3139 """Factory function for a dataset.""" 3140 # NOTE(mrry): `Defun` does not capture the graph-level seed from the 3141 # enclosing graph, so if a graph-level seed is present we set the local 3142 # graph seed based on a combination of the graph- and op-level seeds. 3143 if graph_level_seed is not None: 3144 assert op_level_seed is not None 3145 core_random_seed.set_random_seed( 3146 (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1)) 3147 3148 dataset = self._apply_debug_options() 3149 return dataset._variant_tensor # pylint: disable=protected-access 3150 3151 try: 3152 _make_dataset.add_to_graph(ops.get_default_graph()) 3153 except ValueError as err: 3154 if "Cannot capture a stateful node" in str(err): 3155 raise ValueError( 3156 "Failed to create a one-shot iterator for a dataset. " 3157 "`Dataset.make_one_shot_iterator()` does not support datasets that " 3158 "capture stateful objects, such as a `Variable` or `LookupTable`. " 3159 "In these cases, use `Dataset.make_initializable_iterator()`. " 3160 "(Original error: %s)" % err) 3161 else: 3162 six.reraise(ValueError, err) 3163 3164 with ops.colocate_with(self._variant_tensor): 3165 # pylint: disable=protected-access 3166 return iterator_ops.Iterator( 3167 gen_dataset_ops.one_shot_iterator( 3168 dataset_factory=_make_dataset, **self._flat_structure), None, 3169 get_legacy_output_types(self), get_legacy_output_shapes(self), 3170 get_legacy_output_classes(self)) 3171 3172 @deprecation.deprecated( 3173 None, "This is a deprecated API that should only be used in TF 1 graph " 3174 "mode and legacy TF 2 graph mode available through `tf.compat.v1`. " 3175 "In all other situations -- namely, eager mode and inside `tf.function` " 3176 "-- you can consume dataset elements using `for elem in dataset: ...` " 3177 "or by explicitly creating iterator via `iterator = iter(dataset)` " 3178 "and fetching its elements via `values = next(iterator)`. " 3179 "Furthermore, this API is not available in TF 2. During the transition " 3180 "from TF 1 to TF 2 you can use " 3181 "`tf.compat.v1.data.make_initializable_iterator(dataset)` to create a TF " 3182 "1 graph mode style iterator for a dataset created through TF 2 APIs. " 3183 "Note that this should be a transient state of your code base as there " 3184 "are in general no guarantees about the interoperability of TF 1 and TF " 3185 "2 code.") 3186 def make_initializable_iterator(self, shared_name=None): 3187 """Creates an iterator for elements of this dataset. 
3188 3189 Note: The returned iterator will be in an uninitialized state, 3190 and you must run the `iterator.initializer` operation before using it: 3191 3192 ```python 3193 # Building graph ... 3194 dataset = ... 3195 iterator = dataset.make_initializable_iterator() 3196 next_value = iterator.get_next() # This is a Tensor. 3197 3198 # ... from within a session ... 3199 sess.run(iterator.initializer) 3200 try: 3201 while True: 3202 value = sess.run(next_value) 3203 ... 3204 except tf.errors.OutOfRangeError: 3205 pass 3206 ``` 3207 3208 Args: 3209 shared_name: (Optional.) If non-empty, the returned iterator will be 3210 shared under the given name across multiple sessions that share the same 3211 devices (e.g. when using a remote server). 3212 3213 Returns: 3214 A `tf.data.Iterator` for elements of this dataset. 3215 3216 Raises: 3217 RuntimeError: If eager execution is enabled. 3218 """ 3219 return self._make_initializable_iterator(shared_name) 3220 3221 def _make_initializable_iterator(self, shared_name=None): # pylint: disable=missing-docstring 3222 if context.executing_eagerly(): 3223 raise RuntimeError( 3224 "dataset.make_initializable_iterator is not supported when eager " 3225 "execution is enabled. Use `for element in dataset` instead.") 3226 _ensure_same_dataset_graph(self) 3227 dataset = self._apply_debug_options() 3228 if shared_name is None: 3229 shared_name = "" 3230 3231 with ops.colocate_with(self._variant_tensor): 3232 iterator_resource = gen_dataset_ops.iterator_v2( 3233 container="", shared_name=shared_name, **self._flat_structure) 3234 3235 initializer = gen_dataset_ops.make_iterator( 3236 dataset._variant_tensor, # pylint: disable=protected-access 3237 iterator_resource) 3238 3239 # pylint: disable=protected-access 3240 return iterator_ops.Iterator(iterator_resource, initializer, 3241 get_legacy_output_types(dataset), 3242 get_legacy_output_shapes(dataset), 3243 get_legacy_output_classes(dataset)) 3244 3245 @property 3246 @deprecation.deprecated( 3247 None, "Use `tf.compat.v1.data.get_output_classes(dataset)`.") 3248 def output_classes(self): 3249 """Returns the class of each component of an element of this dataset. 3250 3251 Returns: 3252 A (nested) structure of Python `type` objects corresponding to each 3253 component of an element of this dataset. 3254 """ 3255 return nest.map_structure( 3256 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 3257 self.element_spec) 3258 3259 @property 3260 @deprecation.deprecated( 3261 None, "Use `tf.compat.v1.data.get_output_shapes(dataset)`.") 3262 def output_shapes(self): 3263 """Returns the shape of each component of an element of this dataset. 3264 3265 Returns: 3266 A (nested) structure of `tf.TensorShape` objects corresponding to each 3267 component of an element of this dataset. 3268 """ 3269 return nest.map_structure( 3270 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 3271 self.element_spec) 3272 3273 @property 3274 @deprecation.deprecated( 3275 None, "Use `tf.compat.v1.data.get_output_types(dataset)`.") 3276 def output_types(self): 3277 """Returns the type of each component of an element of this dataset. 3278 3279 Returns: 3280 A (nested) structure of `tf.DType` objects corresponding to each component 3281 of an element of this dataset. 
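    For example, a minimal sketch (the dataset shown is illustrative):

    ```python
    dataset = tf.compat.v1.data.Dataset.from_tensor_slices([1, 2, 3])
    dataset.output_types                         # tf.int32
    # Preferred, forward-compatible accessor:
    tf.compat.v1.data.get_output_types(dataset)  # tf.int32
    ```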
3282 """ 3283 return nest.map_structure( 3284 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 3285 self.element_spec) 3286 3287 @property 3288 def element_spec(self): 3289 # TODO(b/110122868): Remove this override once all `Dataset` instances 3290 # implement `element_structure`. 3291 return structure.convert_legacy_structure( 3292 self.output_types, self.output_shapes, self.output_classes) 3293 3294 @staticmethod 3295 @functools.wraps(DatasetV2.from_tensors) 3296 def from_tensors(tensors): 3297 return DatasetV1Adapter(DatasetV2.from_tensors(tensors)) 3298 3299 @staticmethod 3300 @functools.wraps(DatasetV2.from_tensor_slices) 3301 def from_tensor_slices(tensors): 3302 return DatasetV1Adapter(DatasetV2.from_tensor_slices(tensors)) 3303 3304 @staticmethod 3305 @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") 3306 def from_sparse_tensor_slices(sparse_tensor): 3307 """Splits each rank-N `tf.sparse.SparseTensor` in this dataset row-wise. 3308 3309 Args: 3310 sparse_tensor: A `tf.sparse.SparseTensor`. 3311 3312 Returns: 3313 Dataset: A `Dataset` of rank-(N-1) sparse tensors. 3314 """ 3315 return DatasetV1Adapter(SparseTensorSliceDataset(sparse_tensor)) 3316 3317 @staticmethod 3318 @functools.wraps(DatasetV2.from_generator) 3319 @deprecation.deprecated_args(None, "Use output_signature instead", 3320 "output_types", "output_shapes") 3321 def from_generator(generator, 3322 output_types=None, 3323 output_shapes=None, 3324 args=None, 3325 output_signature=None): 3326 # Calling DatasetV2.from_generator with output_shapes or output_types is 3327 # deprecated, but this is already checked by the decorator on this function. 3328 with deprecation.silence(): 3329 return DatasetV1Adapter( 3330 DatasetV2.from_generator(generator, output_types, output_shapes, args, 3331 output_signature)) 3332 3333 @staticmethod 3334 @functools.wraps(DatasetV2.range) 3335 def range(*args, **kwargs): 3336 return DatasetV1Adapter(DatasetV2.range(*args, **kwargs)) 3337 3338 @staticmethod 3339 @functools.wraps(DatasetV2.zip) 3340 def zip(datasets): 3341 return DatasetV1Adapter(DatasetV2.zip(datasets)) 3342 3343 @functools.wraps(DatasetV2.concatenate) 3344 def concatenate(self, dataset): 3345 return DatasetV1Adapter(super(DatasetV1, self).concatenate(dataset)) 3346 3347 @functools.wraps(DatasetV2.prefetch) 3348 def prefetch(self, buffer_size): 3349 return DatasetV1Adapter(super(DatasetV1, self).prefetch(buffer_size)) 3350 3351 @staticmethod 3352 @functools.wraps(DatasetV2.list_files) 3353 def list_files(file_pattern, shuffle=None, seed=None): 3354 return DatasetV1Adapter(DatasetV2.list_files(file_pattern, shuffle, seed)) 3355 3356 @functools.wraps(DatasetV2.repeat) 3357 def repeat(self, count=None): 3358 return DatasetV1Adapter(super(DatasetV1, self).repeat(count)) 3359 3360 @functools.wraps(DatasetV2.shuffle) 3361 def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None): 3362 return DatasetV1Adapter(super(DatasetV1, self).shuffle( 3363 buffer_size, seed, reshuffle_each_iteration)) 3364 3365 @functools.wraps(DatasetV2.cache) 3366 def cache(self, filename=""): 3367 return DatasetV1Adapter(super(DatasetV1, self).cache(filename)) 3368 3369 @functools.wraps(DatasetV2.take) 3370 def take(self, count): 3371 return DatasetV1Adapter(super(DatasetV1, self).take(count)) 3372 3373 @functools.wraps(DatasetV2.skip) 3374 def skip(self, count): 3375 return DatasetV1Adapter(super(DatasetV1, self).skip(count)) 3376 3377 @functools.wraps(DatasetV2.shard) 
3378 def shard(self, num_shards, index): 3379 return DatasetV1Adapter(super(DatasetV1, self).shard(num_shards, index)) 3380 3381 @functools.wraps(DatasetV2.batch) 3382 def batch(self, 3383 batch_size, 3384 drop_remainder=False, 3385 num_parallel_calls=None, 3386 deterministic=None): 3387 return DatasetV1Adapter( 3388 super(DatasetV1, self).batch(batch_size, drop_remainder, 3389 num_parallel_calls, deterministic)) 3390 3391 @functools.wraps(DatasetV2.padded_batch) 3392 def padded_batch(self, 3393 batch_size, 3394 padded_shapes=None, 3395 padding_values=None, 3396 drop_remainder=False): 3397 return DatasetV1Adapter( 3398 super(DatasetV1, self).padded_batch(batch_size, padded_shapes, 3399 padding_values, drop_remainder)) 3400 3401 @functools.wraps(DatasetV2.map) 3402 def map(self, map_func, num_parallel_calls=None, deterministic=None): 3403 if num_parallel_calls is None or DEBUG_MODE: 3404 return DatasetV1Adapter( 3405 MapDataset(self, map_func, preserve_cardinality=False)) 3406 else: 3407 return DatasetV1Adapter( 3408 ParallelMapDataset( 3409 self, 3410 map_func, 3411 num_parallel_calls, 3412 deterministic, 3413 preserve_cardinality=False)) 3414 3415 @deprecation.deprecated(None, "Use `tf.data.Dataset.map()") 3416 def map_with_legacy_function(self, 3417 map_func, 3418 num_parallel_calls=None, 3419 deterministic=None): 3420 """Maps `map_func` across the elements of this dataset. 3421 3422 Note: This is an escape hatch for existing uses of `map` that do not work 3423 with V2 functions. New uses are strongly discouraged and existing uses 3424 should migrate to `map` as this method will be removed in V2. 3425 3426 Args: 3427 map_func: A function mapping a (nested) structure of tensors (having 3428 shapes and types defined by `self.output_shapes` and 3429 `self.output_types`) to another (nested) structure of tensors. 3430 num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, 3431 representing the number elements to process asynchronously in parallel. 3432 If not specified, elements will be processed sequentially. If the value 3433 `tf.data.AUTOTUNE` is used, then the number of parallel 3434 calls is set dynamically based on available CPU. 3435 deterministic: (Optional.) When `num_parallel_calls` is specified, this 3436 boolean controls the order in which the transformation produces 3437 elements. If set to `False`, the transformation is allowed to yield 3438 elements out of order to trade determinism for performance. If not 3439 specified, the `tf.data.Options.deterministic` option (`True` by 3440 default) controls the behavior. 3441 3442 Returns: 3443 Dataset: A `Dataset`. 
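    As a migration sketch (assuming `map_func` is compatible with V2
    functions):

    ```python
    # Legacy escape hatch:
    dataset = dataset.map_with_legacy_function(map_func)
    # Preferred, equivalent call once `map_func` works with V2 functions:
    dataset = dataset.map(map_func)
    ```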
3444 """ 3445 if num_parallel_calls is None: 3446 if deterministic is not None: 3447 warnings.warn("The `deterministic` argument has no effect unless the " 3448 "`num_parallel_calls` argument is specified.") 3449 return DatasetV1Adapter( 3450 MapDataset( 3451 self, 3452 map_func, 3453 preserve_cardinality=False, 3454 use_legacy_function=True)) 3455 else: 3456 return DatasetV1Adapter( 3457 ParallelMapDataset( 3458 self, 3459 map_func, 3460 num_parallel_calls, 3461 deterministic, 3462 preserve_cardinality=False, 3463 use_legacy_function=True)) 3464 3465 @functools.wraps(DatasetV2.flat_map) 3466 def flat_map(self, map_func): 3467 return DatasetV1Adapter(super(DatasetV1, self).flat_map(map_func)) 3468 3469 @functools.wraps(DatasetV2.interleave) 3470 def interleave(self, 3471 map_func, 3472 cycle_length=None, 3473 block_length=None, 3474 num_parallel_calls=None, 3475 deterministic=None): 3476 return DatasetV1Adapter( 3477 super(DatasetV1, self).interleave(map_func, cycle_length, block_length, 3478 num_parallel_calls, deterministic)) 3479 3480 @functools.wraps(DatasetV2.filter) 3481 def filter(self, predicate): 3482 return DatasetV1Adapter(super(DatasetV1, self).filter(predicate)) 3483 3484 @deprecation.deprecated(None, "Use `tf.data.Dataset.filter()") 3485 def filter_with_legacy_function(self, predicate): 3486 """Filters this dataset according to `predicate`. 3487 3488 Note: This is an escape hatch for existing uses of `filter` that do not work 3489 with V2 functions. New uses are strongly discouraged and existing uses 3490 should migrate to `filter` as this method will be removed in V2. 3491 3492 Args: 3493 predicate: A function mapping a (nested) structure of tensors (having 3494 shapes and types defined by `self.output_shapes` and 3495 `self.output_types`) to a scalar `tf.bool` tensor. 3496 3497 Returns: 3498 Dataset: The `Dataset` containing the elements of this dataset for which 3499 `predicate` is `True`. 
3500 """ 3501 return FilterDataset(self, predicate, use_legacy_function=True) 3502 3503 @functools.wraps(DatasetV2.apply) 3504 def apply(self, transformation_func): 3505 return DatasetV1Adapter(super(DatasetV1, self).apply(transformation_func)) 3506 3507 @functools.wraps(DatasetV2.window) 3508 def window(self, size, shift=None, stride=1, drop_remainder=False): 3509 return DatasetV1Adapter(super(DatasetV1, self).window( 3510 size, shift, stride, drop_remainder)) 3511 3512 @functools.wraps(DatasetV2.unbatch) 3513 def unbatch(self): 3514 return DatasetV1Adapter(super(DatasetV1, self).unbatch()) 3515 3516 @functools.wraps(DatasetV2.with_options) 3517 def with_options(self, options): 3518 return DatasetV1Adapter(super(DatasetV1, self).with_options(options)) 3519 3520 3521if tf2.enabled(): 3522 Dataset = DatasetV2 3523else: 3524 Dataset = DatasetV1 3525 3526 3527class DatasetV1Adapter(DatasetV1): 3528 """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API.""" 3529 3530 def __init__(self, dataset): 3531 self._dataset = dataset 3532 super(DatasetV1Adapter, self).__init__() 3533 3534 def _as_variant_tensor(self): 3535 return self._dataset._variant_tensor # pylint: disable=protected-access 3536 3537 def _inputs(self): 3538 return self._dataset._inputs() # pylint: disable=protected-access 3539 3540 def _functions(self): 3541 return self._dataset._functions() # pylint: disable=protected-access 3542 3543 def options(self): 3544 return self._dataset.options() 3545 3546 @property 3547 def element_spec(self): 3548 return self._dataset.element_spec # pylint: disable=protected-access 3549 3550 def __iter__(self): 3551 return iter(self._dataset) 3552 3553 3554def _ensure_same_dataset_graph(dataset): 3555 """Walks the dataset graph to ensure all datasets come from the same graph.""" 3556 # pylint: disable=protected-access 3557 current_graph = ops.get_default_graph() 3558 bfs_q = Queue.Queue() 3559 bfs_q.put(dataset) 3560 visited = [] 3561 while not bfs_q.empty(): 3562 ds = bfs_q.get() 3563 visited.append(ds) 3564 ds_graph = ds._graph 3565 if current_graph != ds_graph: 3566 raise ValueError( 3567 "The graph (" + str(current_graph) + ") of the iterator is different " 3568 "from the graph (" + str(ds_graph) + ") the dataset: " + 3569 str(ds._variant_tensor) + " was created in. If you are using the " 3570 "Estimator API, make sure that no part of the dataset returned by " 3571 "the `input_fn` function is defined outside the `input_fn` function. " 3572 "Please ensure that all datasets in the pipeline are created in the " 3573 "same graph as the iterator.") 3574 for input_ds in ds._inputs(): 3575 if input_ds not in visited: 3576 bfs_q.put(input_ds) 3577 3578 3579@tf_export(v1=["data.make_one_shot_iterator"]) 3580def make_one_shot_iterator(dataset): 3581 """Creates an iterator for elements of `dataset`. 3582 3583 Note: The returned iterator will be initialized automatically. 3584 A "one-shot" iterator does not support re-initialization. 3585 3586 Args: 3587 dataset: A `tf.data.Dataset`. 3588 3589 Returns: 3590 A `tf.data.Iterator` for elements of `dataset`. 3591 3592 @compatibility(TF2) 3593 This is a legacy API for consuming dataset elements and should only be used 3594 during transition from TF 1 to TF 2. Note that using this API should be 3595 a transient state of your code base as there are in general no guarantees 3596 about the interoperability of TF 1 and TF 2 code. 
3597 3598 In TF 2 datasets are Python iterables which means you can consume their 3599 elements using `for elem in dataset: ...` or by explicitly creating iterator 3600 via `iterator = iter(dataset)` and fetching its elements via 3601 `values = next(iterator)`. 3602 @end_compatibility 3603 """ 3604 try: 3605 # Call the defined `_make_one_shot_iterator()` if there is one, because some 3606 # datasets (e.g. for prefetching) override its behavior. 3607 return dataset._make_one_shot_iterator() # pylint: disable=protected-access 3608 except AttributeError: 3609 return DatasetV1Adapter(dataset)._make_one_shot_iterator() # pylint: disable=protected-access 3610 3611 3612@tf_export(v1=["data.make_initializable_iterator"]) 3613def make_initializable_iterator(dataset, shared_name=None): 3614 """Creates an iterator for elements of `dataset`. 3615 3616 Note: The returned iterator will be in an uninitialized state, 3617 and you must run the `iterator.initializer` operation before using it: 3618 3619 ```python 3620 dataset = ... 3621 iterator = tf.compat.v1.data.make_initializable_iterator(dataset) 3622 # ... 3623 sess.run(iterator.initializer) 3624 ``` 3625 3626 Args: 3627 dataset: A `tf.data.Dataset`. 3628 shared_name: (Optional.) If non-empty, the returned iterator will be shared 3629 under the given name across multiple sessions that share the same devices 3630 (e.g. when using a remote server). 3631 3632 Returns: 3633 A `tf.data.Iterator` for elements of `dataset`. 3634 3635 Raises: 3636 RuntimeError: If eager execution is enabled. 3637 3638 @compatibility(TF2) 3639 This is a legacy API for consuming dataset elements and should only be used 3640 during transition from TF 1 to TF 2. Note that using this API should be 3641 a transient state of your code base as there are in general no guarantees 3642 about the interoperability of TF 1 and TF 2 code. 3643 3644 In TF 2 datasets are Python iterables which means you can consume their 3645 elements using `for elem in dataset: ...` or by explicitly creating iterator 3646 via `iterator = iter(dataset)` and fetching its elements via 3647 `values = next(iterator)`. 3648 @end_compatibility 3649 """ 3650 try: 3651 # Call the defined `_make_initializable_iterator()` if there is one, because 3652 # some datasets (e.g. for prefetching) override its behavior. 3653 return dataset._make_initializable_iterator(shared_name) # pylint: disable=protected-access 3654 except AttributeError: 3655 return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name) # pylint: disable=protected-access 3656 3657 3658@tf_export("data.experimental.get_structure") 3659def get_structure(dataset_or_iterator): 3660 """Returns the type signature for elements of the input dataset / iterator. 3661 3662 Args: 3663 dataset_or_iterator: A `tf.data.Dataset` or an `tf.data.Iterator`. 3664 3665 Returns: 3666 A (nested) structure of `tf.TypeSpec` objects matching the structure of an 3667 element of `dataset_or_iterator` and specifying the type of individual 3668 components. 3669 3670 Raises: 3671 TypeError: If input is not a `tf.data.Dataset` or an `tf.data.Iterator` 3672 object. 3673 """ 3674 try: 3675 return dataset_or_iterator.element_spec # pylint: disable=protected-access 3676 except AttributeError: 3677 raise TypeError("`dataset_or_iterator` must be a `tf.data.Dataset` or " 3678 "tf.data.Iterator object, but got %s." 
                % type(dataset_or_iterator))


@tf_export(v1=["data.get_output_classes"])
def get_legacy_output_classes(dataset_or_iterator):
  """Returns the output classes for elements of the input dataset / iterator.

  Args:
    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.

  Returns:
    A (nested) structure of Python `type` objects matching the structure of the
    dataset / iterator elements and specifying the class of the individual
    components.

  @compatibility(TF2)
  This is a legacy API for inspecting the type signature of dataset elements.
  In TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead.
  @end_compatibility
  """
  return nest.map_structure(
      lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
      get_structure(dataset_or_iterator))


@tf_export(v1=["data.get_output_shapes"])
def get_legacy_output_shapes(dataset_or_iterator):
  """Returns the output shapes for elements of the input dataset / iterator.

  Args:
    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.

  Returns:
    A (nested) structure of `tf.TensorShape` objects matching the structure of
    the dataset / iterator elements and specifying the shape of the individual
    components.

  @compatibility(TF2)
  This is a legacy API for inspecting the type signature of dataset elements.
  In TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead.
  @end_compatibility
  """
  return nest.map_structure(
      lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
      get_structure(dataset_or_iterator))


@tf_export(v1=["data.get_output_types"])
def get_legacy_output_types(dataset_or_iterator):
  """Returns the output types for elements of the input dataset / iterator.

  Args:
    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.

  Returns:
    A (nested) structure of `tf.DType` objects matching the structure of
    dataset / iterator elements and specifying the type of the individual
    components.

  @compatibility(TF2)
  This is a legacy API for inspecting the type signature of dataset elements.
  In TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead.
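  For example, a minimal sketch of the TF 2 replacement:

  ```python
  dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0])
  dataset.element_spec  # TensorSpec(shape=(), dtype=tf.float32, name=None)
  ```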
3741 @end_compatibility 3742 """ 3743 return nest.map_structure( 3744 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 3745 get_structure(dataset_or_iterator)) 3746 3747 3748class DatasetSource(DatasetV2): 3749 """Abstract class representing a dataset with no inputs.""" 3750 3751 def _inputs(self): 3752 return [] 3753 3754 3755class UnaryDataset(DatasetV2): 3756 """Abstract class representing a dataset with one input.""" 3757 3758 def __init__(self, input_dataset, variant_tensor): 3759 self._input_dataset = input_dataset 3760 super(UnaryDataset, self).__init__(variant_tensor) 3761 3762 def _inputs(self): 3763 return [self._input_dataset] 3764 3765 3766class UnaryUnchangedStructureDataset(UnaryDataset): 3767 """Represents a unary dataset with the same input and output structure.""" 3768 3769 def __init__(self, input_dataset, variant_tensor): 3770 self._input_dataset = input_dataset 3771 super(UnaryUnchangedStructureDataset, self).__init__( 3772 input_dataset, variant_tensor) 3773 3774 @property 3775 def element_spec(self): 3776 return self._input_dataset.element_spec 3777 3778 3779class TensorDataset(DatasetSource): 3780 """A `Dataset` with a single element.""" 3781 3782 def __init__(self, element): 3783 """See `Dataset.from_tensors()` for details.""" 3784 element = structure.normalize_element(element) 3785 self._structure = structure.type_spec_from_value(element) 3786 self._tensors = structure.to_tensor_list(self._structure, element) 3787 3788 variant_tensor = gen_dataset_ops.tensor_dataset( 3789 self._tensors, 3790 output_shapes=structure.get_flat_tensor_shapes(self._structure)) 3791 super(TensorDataset, self).__init__(variant_tensor) 3792 3793 @property 3794 def element_spec(self): 3795 return self._structure 3796 3797 3798class TensorSliceDataset(DatasetSource): 3799 """A `Dataset` of slices from a dataset element.""" 3800 3801 def __init__(self, element): 3802 """See `Dataset.from_tensor_slices()` for details.""" 3803 element = structure.normalize_element(element) 3804 batched_spec = structure.type_spec_from_value(element) 3805 self._tensors = structure.to_batched_tensor_list(batched_spec, element) 3806 self._structure = nest.map_structure( 3807 lambda component_spec: component_spec._unbatch(), batched_spec) # pylint: disable=protected-access 3808 3809 batch_dim = tensor_shape.Dimension(tensor_shape.dimension_value( 3810 self._tensors[0].get_shape()[0])) 3811 for t in self._tensors[1:]: 3812 batch_dim.assert_is_compatible_with(tensor_shape.Dimension( 3813 tensor_shape.dimension_value(t.get_shape()[0]))) 3814 3815 variant_tensor = gen_dataset_ops.tensor_slice_dataset( 3816 self._tensors, 3817 output_shapes=structure.get_flat_tensor_shapes(self._structure)) 3818 super(TensorSliceDataset, self).__init__(variant_tensor) 3819 3820 @property 3821 def element_spec(self): 3822 return self._structure 3823 3824 3825class SparseTensorSliceDataset(DatasetSource): 3826 """A `Dataset` that splits a rank-N `tf.sparse.SparseTensor` into its rows.""" 3827 3828 def __init__(self, sparse_tensor): 3829 """See `Dataset.from_sparse_tensor_slices()` for details.""" 3830 if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor): 3831 raise TypeError( 3832 "`sparse_tensor` must be a `tf.sparse.SparseTensor` object." 
3833 "Was {}.".format(sparse_tensor)) 3834 self._sparse_tensor = sparse_tensor 3835 3836 indices_shape = self._sparse_tensor.indices.get_shape() 3837 shape_shape = self._sparse_tensor.dense_shape.get_shape() 3838 rank = (indices_shape.dims[1] - 1).merge_with(shape_shape.dims[0] - 1) 3839 self._structure = (tensor_spec.TensorSpec([None, rank], dtypes.int64), 3840 tensor_spec.TensorSpec([None], 3841 self._sparse_tensor.dtype), 3842 tensor_spec.TensorSpec([rank], dtypes.int64)) 3843 3844 variant_tensor = gen_dataset_ops.sparse_tensor_slice_dataset( 3845 self._sparse_tensor.indices, self._sparse_tensor.values, 3846 self._sparse_tensor.dense_shape) 3847 super(SparseTensorSliceDataset, self).__init__(variant_tensor) 3848 3849 @property 3850 def element_spec(self): 3851 return self._structure 3852 3853 3854class _VariantDataset(DatasetV2): 3855 """A Dataset wrapper around a `tf.variant`-typed function argument.""" 3856 3857 def __init__(self, dataset_variant, structure): 3858 self._structure = structure 3859 super(_VariantDataset, self).__init__(dataset_variant) 3860 3861 def _inputs(self): 3862 return [] 3863 3864 @property 3865 def element_spec(self): 3866 return self._structure 3867 3868 3869class _NestedVariant(composite_tensor.CompositeTensor): 3870 3871 def __init__(self, variant_tensor, element_spec, dataset_shape): 3872 self._variant_tensor = variant_tensor 3873 self._element_spec = element_spec 3874 self._dataset_shape = dataset_shape 3875 3876 @property 3877 def _type_spec(self): 3878 return DatasetSpec(self._element_spec, self._dataset_shape) 3879 3880 3881@tf_export("data.experimental.from_variant") 3882def from_variant(variant, structure): 3883 """Constructs a dataset from the given variant and (nested) structure. 3884 3885 Args: 3886 variant: A scalar `tf.variant` tensor representing a dataset. 3887 structure: A (nested) structure of `tf.TypeSpec` objects representing the 3888 structure of each element in the dataset. 3889 3890 Returns: 3891 A `tf.data.Dataset` instance. 3892 """ 3893 return _VariantDataset(variant, structure) # pylint: disable=protected-access 3894 3895 3896@tf_export("data.experimental.to_variant") 3897def to_variant(dataset): 3898 """Returns a variant representing the given dataset. 3899 3900 Args: 3901 dataset: A `tf.data.Dataset`. 3902 3903 Returns: 3904 A scalar `tf.variant` tensor representing the given dataset. 3905 """ 3906 return dataset._variant_tensor # pylint: disable=protected-access 3907 3908 3909@tf_export( 3910 "data.DatasetSpec", 3911 v1=["data.DatasetSpec", "data.experimental.DatasetStructure"]) 3912class DatasetSpec(type_spec.BatchableTypeSpec): 3913 """Type specification for `tf.data.Dataset`. 3914 3915 See `tf.TypeSpec` for more information about TensorFlow type specifications. 
3916 3917 >>> dataset = tf.data.Dataset.range(3) 3918 >>> tf.data.DatasetSpec.from_value(dataset) 3919 DatasetSpec(TensorSpec(shape=(), dtype=tf.int64, name=None), TensorShape([])) 3920 """ 3921 3922 __slots__ = ["_element_spec", "_dataset_shape"] 3923 3924 def __init__(self, element_spec, dataset_shape=()): 3925 self._element_spec = element_spec 3926 self._dataset_shape = tensor_shape.as_shape(dataset_shape) 3927 3928 @property 3929 def value_type(self): 3930 return Dataset 3931 3932 @property 3933 def element_spec(self): 3934 """The inner element spec.""" 3935 return self._element_spec 3936 3937 def _serialize(self): 3938 return (self._element_spec, self._dataset_shape) 3939 3940 @property 3941 def _component_specs(self): 3942 return tensor_spec.TensorSpec(self._dataset_shape, dtypes.variant) 3943 3944 def _to_components(self, value): 3945 return value._variant_tensor # pylint: disable=protected-access 3946 3947 def _from_components(self, components): 3948 # pylint: disable=protected-access 3949 if self._dataset_shape.ndims == 0: 3950 return _VariantDataset(components, self._element_spec) 3951 else: 3952 return _NestedVariant(components, self._element_spec, self._dataset_shape) 3953 3954 def _to_tensor_list(self, value): 3955 return [ 3956 ops.convert_to_tensor( 3957 tf_nest.map_structure(lambda x: x._variant_tensor, value)) # pylint: disable=protected-access 3958 ] 3959 3960 @staticmethod 3961 def from_value(value): 3962 """Creates a `DatasetSpec` for the given `tf.data.Dataset` value.""" 3963 return DatasetSpec(value.element_spec) # pylint: disable=protected-access 3964 3965 def _batch(self, batch_size): 3966 return DatasetSpec( 3967 self._element_spec, 3968 tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape)) 3969 3970 def _unbatch(self): 3971 if self._dataset_shape.ndims == 0: 3972 raise ValueError("Unbatching a dataset is only supported for rank >= 1") 3973 return DatasetSpec(self._element_spec, self._dataset_shape[1:]) 3974 3975 def _to_batched_tensor_list(self, value): 3976 if self._dataset_shape.ndims == 0: 3977 raise ValueError("Unbatching a dataset is only supported for rank >= 1") 3978 return self._to_tensor_list(value) 3979 3980 def _to_legacy_output_types(self): 3981 return self 3982 3983 def _to_legacy_output_shapes(self): 3984 return self 3985 3986 def _to_legacy_output_classes(self): 3987 return self 3988 3989 3990class StructuredFunctionWrapper(object): 3991 """A function wrapper that supports structured arguments and return values.""" 3992 3993 def __init__(self, 3994 func, 3995 transformation_name, 3996 dataset=None, 3997 input_classes=None, 3998 input_shapes=None, 3999 input_types=None, 4000 input_structure=None, 4001 add_to_graph=True, 4002 use_legacy_function=False, 4003 defun_kwargs=None): 4004 """Creates a new `StructuredFunctionWrapper` for the given function. 4005 4006 Args: 4007 func: A function from a (nested) structure to another (nested) structure. 4008 transformation_name: Human-readable name of the transformation in which 4009 this function is being instantiated, for error messages. 4010 dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this 4011 dataset will be assumed as the structure for `func` arguments; otherwise 4012 `input_classes`, `input_shapes`, and `input_types` must be defined. 4013 input_classes: (Optional.) A (nested) structure of `type`. If given, this 4014 argument defines the Python types for `func` arguments. 4015 input_shapes: (Optional.) A (nested) structure of `tf.TensorShape`. 
If 4016 given, this argument defines the shapes and structure for `func` 4017 arguments. 4018 input_types: (Optional.) A (nested) structure of `tf.DType`. If given, 4019 this argument defines the element types and structure for `func` 4020 arguments. 4021 input_structure: (Optional.) A `Structure` object. If given, this argument 4022 defines the element types and structure for `func` arguments. 4023 add_to_graph: (Optional.) If `True`, the function will be added to the 4024 default graph, if it exists. 4025 use_legacy_function: (Optional.) A boolean that determines whether the 4026 function be created using `tensorflow.python.eager.function.defun` 4027 (default behavior) or `tensorflow.python.framework.function.Defun` 4028 (legacy behavior). 4029 defun_kwargs: (Optional.) A dictionary mapping string argument names to 4030 values. If supplied, will be passed to `function` as keyword arguments. 4031 4032 Raises: 4033 ValueError: If an invalid combination of `dataset`, `input_classes`, 4034 `input_shapes`, and `input_types` is passed. 4035 """ 4036 # pylint: disable=protected-access 4037 if input_structure is None: 4038 if dataset is None: 4039 if input_classes is None or input_shapes is None or input_types is None: 4040 raise ValueError("Either `dataset`, `input_structure` or all of " 4041 "`input_classes`, `input_shapes`, and `input_types` " 4042 "must be specified.") 4043 self._input_structure = structure.convert_legacy_structure( 4044 input_types, input_shapes, input_classes) 4045 else: 4046 if not (input_classes is None and input_shapes is None and 4047 input_types is None): 4048 raise ValueError("Either `dataset`, `input_structure` or all of " 4049 "`input_classes`, `input_shapes`, and `input_types` " 4050 "must be specified.") 4051 self._input_structure = dataset.element_spec 4052 else: 4053 if not (dataset is None and input_classes is None and input_shapes is None 4054 and input_types is None): 4055 raise ValueError("Either `dataset`, `input_structure`, or all of " 4056 "`input_classes`, `input_shapes`, and `input_types` " 4057 "must be specified.") 4058 self._input_structure = input_structure 4059 4060 self._func = func 4061 4062 if defun_kwargs is None: 4063 defun_kwargs = {} 4064 4065 readable_transformation_name = transformation_name.replace( 4066 ".", "_")[:-2] if len(transformation_name) > 2 else "" 4067 4068 func_name = "_".join( 4069 [readable_transformation_name, 4070 function_utils.get_func_name(func)]) 4071 # Sanitize function name to remove symbols that interfere with graph 4072 # construction. 4073 for symbol in ["<", ">", "\\", "'", " "]: 4074 func_name = func_name.replace(symbol, "") 4075 4076 ag_ctx = autograph_ctx.control_status_ctx() 4077 4078 def wrapper_helper(*args): 4079 """Wrapper for passing nested structures to and from tf.data functions.""" 4080 nested_args = structure.from_compatible_tensor_list( 4081 self._input_structure, args) 4082 if not _should_unpack(nested_args): 4083 nested_args = (nested_args,) 4084 ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args) 4085 if _should_pack(ret): 4086 ret = tuple(ret) 4087 4088 try: 4089 self._output_structure = structure.type_spec_from_value(ret) 4090 except (ValueError, TypeError): 4091 six.reraise( 4092 TypeError, 4093 TypeError("Unsupported return value from function passed to " 4094 "%s: %s." 
% (transformation_name, ret)), 4095 sys.exc_info()[2]) 4096 return ret 4097 4098 def trace_legacy_function(defun_kwargs): 4099 @function.Defun(*structure.get_flat_tensor_types(self._input_structure), 4100 **defun_kwargs) 4101 def wrapped_fn(*args): 4102 ret = wrapper_helper(*args) 4103 return structure.to_tensor_list(self._output_structure, ret) 4104 4105 return lambda: wrapped_fn 4106 4107 def trace_py_function(defun_kwargs): 4108 # First we trace the function to infer the output structure. 4109 @eager_function.defun_with_attributes( 4110 input_signature=structure.get_flat_tensor_specs( 4111 self._input_structure), 4112 autograph=False, 4113 attributes=defun_kwargs) 4114 def unused(*args): # pylint: disable=missing-docstring,unused-variable 4115 ret = wrapper_helper(*args) 4116 ret = structure.to_tensor_list(self._output_structure, ret) 4117 return [ops.convert_to_tensor(t) for t in ret] 4118 4119 _ = unused.get_concrete_function() 4120 4121 def py_function_wrapper(*args): 4122 nested_args = structure.from_compatible_tensor_list( 4123 self._input_structure, args) 4124 if not _should_unpack(nested_args): 4125 nested_args = (nested_args,) 4126 ret = self._func(*nested_args) 4127 if _should_pack(ret): 4128 ret = tuple(ret) 4129 ret = structure.to_tensor_list(self._output_structure, ret) 4130 return [ops.convert_to_tensor(t) for t in ret] 4131 4132 # Next we trace the function wrapped in `eager_py_func` to force eager 4133 # execution. 4134 @eager_function.defun_with_attributes( 4135 input_signature=structure.get_flat_tensor_specs( 4136 self._input_structure), 4137 autograph=False, 4138 attributes=defun_kwargs) 4139 def wrapped_fn(*args): # pylint: disable=missing-docstring 4140 return script_ops.eager_py_func( 4141 py_function_wrapper, args, 4142 structure.get_flat_tensor_types(self._output_structure)) 4143 4144 return wrapped_fn.get_concrete_function 4145 4146 def trace_tf_function(defun_kwargs): 4147 # Note: wrapper_helper will apply autograph based on context. 4148 @eager_function.defun_with_attributes( 4149 input_signature=structure.get_flat_tensor_specs( 4150 self._input_structure), 4151 autograph=False, 4152 attributes=defun_kwargs) 4153 def wrapped_fn(*args): # pylint: disable=missing-docstring 4154 ret = wrapper_helper(*args) 4155 ret = structure.to_tensor_list(self._output_structure, ret) 4156 return [ops.convert_to_tensor(t) for t in ret] 4157 4158 return wrapped_fn.get_concrete_function 4159 4160 if use_legacy_function: 4161 defun_kwargs.update({"func_name": func_name + "_" + str(ops.uid())}) 4162 fn_factory = trace_legacy_function(defun_kwargs) 4163 else: 4164 defun_kwargs.update({"func_name": func_name}) 4165 defun_kwargs.update({"_tf_data_function": True}) 4166 if DEBUG_MODE: 4167 fn_factory = trace_py_function(defun_kwargs) 4168 else: 4169 if def_function.functions_run_eagerly(): 4170 warnings.warn( 4171 "Even though the `tf.config.experimental_run_functions_eagerly` " 4172 "option is set, this option does not apply to tf.data functions. " 4173 "To force eager execution of tf.data functions, please use " 4174 "`tf.data.experimental.enable_debug_mode()`.") 4175 fn_factory = trace_tf_function(defun_kwargs) 4176 4177 self._function = fn_factory() 4178 # There is no graph to add in eager mode. 4179 add_to_graph &= not context.executing_eagerly() 4180 # There are some lifetime issues when a legacy function is not added to a 4181 # out-living graph. It's already deprecated so de-prioritizing the fix. 
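    # For that reason, legacy (Defun-based) functions are always added to a
    # graph.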
4182 add_to_graph |= use_legacy_function 4183 if add_to_graph: 4184 self._function.add_to_graph(ops.get_default_graph()) 4185 4186 if not use_legacy_function: 4187 outer_graph_seed = ops.get_default_graph().seed 4188 if outer_graph_seed and self._function.graph.seed == outer_graph_seed: 4189 if self._function.graph._seed_used: 4190 warnings.warn( 4191 "Seed %s from outer graph might be getting used by function %s, " 4192 "if the random op has not been provided any seed. Explicitly set " 4193 "the seed in the function if this is not the intended behavior." 4194 %(outer_graph_seed, func_name), stacklevel=4) 4195 4196 @property 4197 def output_structure(self): 4198 return self._output_structure 4199 4200 @property 4201 def output_classes(self): 4202 return nest.map_structure( 4203 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 4204 self._output_structure) 4205 4206 @property 4207 def output_shapes(self): 4208 return nest.map_structure( 4209 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 4210 self._output_structure) 4211 4212 @property 4213 def output_types(self): 4214 return nest.map_structure( 4215 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 4216 self._output_structure) 4217 4218 @property 4219 def function(self): 4220 return self._function 4221 4222 4223class _GeneratorDataset(DatasetSource): 4224 """A `Dataset` that generates elements by invoking a function.""" 4225 4226 def __init__(self, init_args, init_func, next_func, finalize_func, 4227 output_signature): 4228 """Constructs a `_GeneratorDataset`. 4229 4230 Args: 4231 init_args: A (nested) structure representing the arguments to `init_func`. 4232 init_func: A TensorFlow function that will be called on `init_args` each 4233 time a C++ iterator over this dataset is constructed. Returns a (nested) 4234 structure representing the "state" of the dataset. 4235 next_func: A TensorFlow function that will be called on the result of 4236 `init_func` to produce each element, and that raises `OutOfRangeError` 4237 to terminate iteration. 4238 finalize_func: A TensorFlow function that will be called on the result of 4239 `init_func` immediately before a C++ iterator over this dataset is 4240 destroyed. The return value is ignored. 4241 output_signature: A (nested) structure of `tf.TypeSpec` objects describing 4242 the output of `next_func`. 
4243 """ 4244 self._init_args = init_args 4245 4246 self._init_structure = structure.type_spec_from_value(init_args) 4247 4248 self._init_func = StructuredFunctionWrapper( 4249 init_func, 4250 self._transformation_name(), 4251 input_structure=self._init_structure) 4252 4253 self._next_func = StructuredFunctionWrapper( 4254 next_func, 4255 self._transformation_name(), 4256 input_structure=self._init_func.output_structure) 4257 4258 self._finalize_func = StructuredFunctionWrapper( 4259 finalize_func, 4260 self._transformation_name(), 4261 input_structure=self._init_func.output_structure) 4262 4263 self._output_signature = output_signature 4264 4265 variant_tensor = gen_dataset_ops.generator_dataset( 4266 structure.to_tensor_list(self._init_structure, self._init_args) + 4267 self._init_func.function.captured_inputs, 4268 self._next_func.function.captured_inputs, 4269 self._finalize_func.function.captured_inputs, 4270 init_func=self._init_func.function, 4271 next_func=self._next_func.function, 4272 finalize_func=self._finalize_func.function, 4273 **self._flat_structure) 4274 super(_GeneratorDataset, self).__init__(variant_tensor) 4275 4276 @property 4277 def element_spec(self): 4278 return self._output_signature 4279 4280 def _transformation_name(self): 4281 return "Dataset.from_generator()" 4282 4283 4284class ZipDataset(DatasetV2): 4285 """A `Dataset` that zips its inputs together.""" 4286 4287 def __init__(self, datasets): 4288 """See `Dataset.zip()` for details.""" 4289 for ds in nest.flatten(datasets): 4290 if not isinstance(ds, DatasetV2): 4291 if isinstance(ds, list): 4292 message = ("The argument to `Dataset.zip()` must be a (nested) " 4293 "structure of `Dataset` objects. Python `list` is not " 4294 "supported, please use a `tuple` instead.") 4295 else: 4296 message = ("The argument to `Dataset.zip()` must be a (nested) " 4297 "structure of `Dataset` objects.") 4298 raise TypeError(message) 4299 self._datasets = datasets 4300 self._structure = nest.pack_sequence_as( 4301 self._datasets, 4302 [ds.element_spec for ds in nest.flatten(self._datasets)]) 4303 variant_tensor = gen_dataset_ops.zip_dataset( 4304 [ds._variant_tensor for ds in nest.flatten(self._datasets)], 4305 **self._flat_structure) 4306 super(ZipDataset, self).__init__(variant_tensor) 4307 4308 def _inputs(self): 4309 return nest.flatten(self._datasets) 4310 4311 @property 4312 def element_spec(self): 4313 return self._structure 4314 4315 4316class ConcatenateDataset(DatasetV2): 4317 """A `Dataset` that concatenates its input with given dataset.""" 4318 4319 def __init__(self, input_dataset, dataset_to_concatenate): 4320 """See `Dataset.concatenate()` for details.""" 4321 self._input_dataset = input_dataset 4322 self._dataset_to_concatenate = dataset_to_concatenate 4323 4324 output_types = get_legacy_output_types(input_dataset) 4325 if output_types != get_legacy_output_types(dataset_to_concatenate): 4326 raise TypeError( 4327 "Two datasets to concatenate have different types %s and %s" % 4328 (output_types, get_legacy_output_types(dataset_to_concatenate))) 4329 4330 output_classes = get_legacy_output_classes(input_dataset) 4331 if output_classes != get_legacy_output_classes(dataset_to_concatenate): 4332 raise TypeError( 4333 "Two datasets to concatenate have different classes %s and %s" % 4334 (output_classes, get_legacy_output_classes(dataset_to_concatenate))) 4335 4336 spec1 = input_dataset.element_spec 4337 spec2 = dataset_to_concatenate.element_spec 4338 self._structure = nest.pack_sequence_as(spec1, [ 4339 
ts1.most_specific_compatible_type(ts2) 4340 for (ts1, ts2) in zip(nest.flatten(spec1), nest.flatten(spec2)) 4341 ]) 4342 4343 self._input_datasets = [input_dataset, dataset_to_concatenate] 4344 # pylint: disable=protected-access 4345 variant_tensor = gen_dataset_ops.concatenate_dataset( 4346 input_dataset._variant_tensor, dataset_to_concatenate._variant_tensor, 4347 **self._flat_structure) 4348 # pylint: enable=protected-access 4349 super(ConcatenateDataset, self).__init__(variant_tensor) 4350 4351 def _inputs(self): 4352 return self._input_datasets 4353 4354 @property 4355 def element_spec(self): 4356 return self._structure 4357 4358 4359class RepeatDataset(UnaryUnchangedStructureDataset): 4360 """A `Dataset` that repeats its input several times.""" 4361 4362 def __init__(self, input_dataset, count): 4363 """See `Dataset.repeat()` for details.""" 4364 self._input_dataset = input_dataset 4365 if count is None: 4366 self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") 4367 else: 4368 self._count = ops.convert_to_tensor( 4369 count, dtype=dtypes.int64, name="count") 4370 variant_tensor = gen_dataset_ops.repeat_dataset( 4371 input_dataset._variant_tensor, # pylint: disable=protected-access 4372 count=self._count, 4373 **self._flat_structure) 4374 super(RepeatDataset, self).__init__(input_dataset, variant_tensor) 4375 4376 4377class RangeDataset(DatasetSource): 4378 """A `Dataset` of a step separated range of values.""" 4379 4380 def __init__(self, *args, **kwargs): 4381 """See `Dataset.range()` for details.""" 4382 self._parse_args(*args, **kwargs) 4383 self._structure = tensor_spec.TensorSpec([], self._output_type) 4384 variant_tensor = gen_dataset_ops.range_dataset( 4385 start=self._start, 4386 stop=self._stop, 4387 step=self._step, 4388 **self._flat_structure) 4389 super(RangeDataset, self).__init__(variant_tensor) 4390 4391 def _parse_args(self, *args, **kwargs): 4392 """Parse arguments according to the same rules as the `range()` builtin.""" 4393 if len(args) == 1: 4394 self._start = self._build_tensor(0, "start") 4395 self._stop = self._build_tensor(args[0], "stop") 4396 self._step = self._build_tensor(1, "step") 4397 elif len(args) == 2: 4398 self._start = self._build_tensor(args[0], "start") 4399 self._stop = self._build_tensor(args[1], "stop") 4400 self._step = self._build_tensor(1, "step") 4401 elif len(args) == 3: 4402 self._start = self._build_tensor(args[0], "start") 4403 self._stop = self._build_tensor(args[1], "stop") 4404 self._step = self._build_tensor(args[2], "step") 4405 else: 4406 raise ValueError("Invalid arguments to RangeDataset: %s" % str(args)) 4407 if "output_type" in kwargs: 4408 self._output_type = kwargs["output_type"] 4409 else: 4410 self._output_type = dtypes.int64 4411 4412 def _build_tensor(self, int64_value, name): 4413 return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name) 4414 4415 @property 4416 def element_spec(self): 4417 return self._structure 4418 4419 4420class CacheDataset(UnaryUnchangedStructureDataset): 4421 """A `Dataset` that caches elements of its input.""" 4422 4423 def __init__(self, input_dataset, filename): 4424 """See `Dataset.cache()` for details.""" 4425 self._input_dataset = input_dataset 4426 self._filename = ops.convert_to_tensor( 4427 filename, dtype=dtypes.string, name="filename") 4428 if tf2.enabled() and (context.executing_eagerly() or ops.inside_function()): 4429 variant_tensor = gen_dataset_ops.cache_dataset_v2( 4430 input_dataset._variant_tensor, # pylint: disable=protected-access 4431 
filename=self._filename, 4432 cache=gen_dataset_ops.dummy_memory_cache(), 4433 **self._flat_structure) 4434 else: 4435 variant_tensor = gen_dataset_ops.cache_dataset( 4436 input_dataset._variant_tensor, # pylint: disable=protected-access 4437 filename=self._filename, 4438 **self._flat_structure) 4439 super(CacheDataset, self).__init__(input_dataset, variant_tensor) 4440 4441 4442class ShuffleDataset(UnaryUnchangedStructureDataset): 4443 """A `Dataset` that randomly shuffles the elements of its input.""" 4444 4445 def __init__(self, 4446 input_dataset, 4447 buffer_size, 4448 seed=None, 4449 reshuffle_each_iteration=None): 4450 """Randomly shuffles the elements of this dataset. 4451 4452 Args: 4453 input_dataset: The input dataset. 4454 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 4455 elements from this dataset from which the new dataset will sample. 4456 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random 4457 seed that will be used to create the distribution. See 4458 `tf.random.set_seed` for behavior. 4459 reshuffle_each_iteration: (Optional.) A boolean, which if true indicates 4460 that the dataset should be pseudorandomly reshuffled each time it is 4461 iterated over. (Defaults to `True`.) 4462 4463 Returns: 4464 A `Dataset`. 4465 4466 Raises: 4467 ValueError: if invalid arguments are provided. 4468 """ 4469 self._input_dataset = input_dataset 4470 self._buffer_size = ops.convert_to_tensor( 4471 buffer_size, dtype=dtypes.int64, name="buffer_size") 4472 self._seed, self._seed2 = random_seed.get_seed(seed) 4473 if reshuffle_each_iteration is None: 4474 reshuffle_each_iteration = True 4475 self._reshuffle_each_iteration = reshuffle_each_iteration 4476 4477 if (tf2.enabled() and 4478 (context.executing_eagerly() or ops.inside_function())): 4479 variant_tensor = gen_dataset_ops.shuffle_dataset_v3( 4480 input_dataset._variant_tensor, # pylint: disable=protected-access 4481 buffer_size=self._buffer_size, 4482 seed=self._seed, 4483 seed2=self._seed2, 4484 seed_generator=gen_dataset_ops.dummy_seed_generator(), 4485 reshuffle_each_iteration=self._reshuffle_each_iteration, 4486 **self._flat_structure) 4487 else: 4488 variant_tensor = gen_dataset_ops.shuffle_dataset( 4489 input_dataset._variant_tensor, # pylint: disable=protected-access 4490 buffer_size=self._buffer_size, 4491 seed=self._seed, 4492 seed2=self._seed2, 4493 reshuffle_each_iteration=self._reshuffle_each_iteration, 4494 **self._flat_structure) 4495 super(ShuffleDataset, self).__init__(input_dataset, variant_tensor) 4496 4497 4498class TakeDataset(UnaryUnchangedStructureDataset): 4499 """A `Dataset` containing the first `count` elements from its input.""" 4500 4501 def __init__(self, input_dataset, count): 4502 """See `Dataset.take()` for details.""" 4503 self._input_dataset = input_dataset 4504 self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") 4505 variant_tensor = gen_dataset_ops.take_dataset( 4506 input_dataset._variant_tensor, # pylint: disable=protected-access 4507 count=self._count, 4508 **self._flat_structure) 4509 super(TakeDataset, self).__init__(input_dataset, variant_tensor) 4510 4511 4512class SkipDataset(UnaryUnchangedStructureDataset): 4513 """A `Dataset` skipping the first `count` elements from its input.""" 4514 4515 def __init__(self, input_dataset, count): 4516 """See `Dataset.skip()` for details.""" 4517 self._input_dataset = input_dataset 4518 self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") 4519 variant_tensor 
= gen_dataset_ops.skip_dataset( 4520 input_dataset._variant_tensor, # pylint: disable=protected-access 4521 count=self._count, 4522 **self._flat_structure) 4523 super(SkipDataset, self).__init__(input_dataset, variant_tensor) 4524 4525 4526class ShardDataset(UnaryUnchangedStructureDataset): 4527 """A `Dataset` for sharding its input.""" 4528 4529 def __init__(self, input_dataset, num_shards, index): 4530 """See `Dataset.shard()` for details.""" 4531 self._input_dataset = input_dataset 4532 self._num_shards = ops.convert_to_tensor( 4533 num_shards, dtype=dtypes.int64, name="num_shards") 4534 self._index = ops.convert_to_tensor(index, dtype=dtypes.int64, name="index") 4535 variant_tensor = gen_dataset_ops.shard_dataset( 4536 input_dataset._variant_tensor, # pylint: disable=protected-access 4537 num_shards=self._num_shards, 4538 index=self._index, 4539 **self._flat_structure) 4540 super(ShardDataset, self).__init__(input_dataset, variant_tensor) 4541 4542 4543class BatchDataset(UnaryDataset): 4544 """A `Dataset` that batches contiguous elements from its input.""" 4545 4546 def __init__(self, input_dataset, batch_size, drop_remainder): 4547 """See `Dataset.batch()` for details.""" 4548 self._input_dataset = input_dataset 4549 self._batch_size = ops.convert_to_tensor( 4550 batch_size, dtype=dtypes.int64, name="batch_size") 4551 self._drop_remainder = ops.convert_to_tensor( 4552 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 4553 4554 constant_drop_remainder = tensor_util.constant_value(self._drop_remainder) 4555 # pylint: disable=protected-access 4556 if constant_drop_remainder: 4557 # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically) 4558 # or `False` (explicitly retaining the remainder). 4559 # pylint: disable=g-long-lambda 4560 constant_batch_size = tensor_util.constant_value(self._batch_size) 4561 self._structure = nest.map_structure( 4562 lambda component_spec: component_spec._batch(constant_batch_size), 4563 input_dataset.element_spec) 4564 else: 4565 self._structure = nest.map_structure( 4566 lambda component_spec: component_spec._batch(None), 4567 input_dataset.element_spec) 4568 variant_tensor = gen_dataset_ops.batch_dataset_v2( 4569 input_dataset._variant_tensor, 4570 batch_size=self._batch_size, 4571 drop_remainder=self._drop_remainder, 4572 **self._flat_structure) 4573 super(BatchDataset, self).__init__(input_dataset, variant_tensor) 4574 4575 @property 4576 def element_spec(self): 4577 return self._structure 4578 4579 4580class ParallelBatchDataset(UnaryDataset): 4581 """A `Dataset` that batches contiguous elements from its input in parallel.""" 4582 4583 def __init__(self, input_dataset, batch_size, drop_remainder, 4584 num_parallel_calls, deterministic): 4585 """See `Dataset.batch()` for details.""" 4586 self._input_dataset = input_dataset 4587 self._batch_size = ops.convert_to_tensor( 4588 batch_size, dtype=dtypes.int64, name="batch_size") 4589 self._drop_remainder = ops.convert_to_tensor( 4590 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 4591 self._num_parallel_calls = ops.convert_to_tensor( 4592 num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") 4593 if deterministic is None: 4594 self._deterministic = "default" 4595 elif deterministic: 4596 self._deterministic = "true" 4597 else: 4598 self._deterministic = "false" 4599 4600 constant_drop_remainder = tensor_util.constant_value(self._drop_remainder) 4601 # pylint: disable=protected-access 4602 if constant_drop_remainder: 4603 # NOTE(mrry): `constant_drop_remainder` 
may be `None` (unknown statically) 4604 # or `False` (explicitly retaining the remainder). 4605 # pylint: disable=g-long-lambda 4606 constant_batch_size = tensor_util.constant_value(self._batch_size) 4607 self._structure = nest.map_structure( 4608 lambda component_spec: component_spec._batch(constant_batch_size), 4609 input_dataset.element_spec) 4610 else: 4611 self._structure = nest.map_structure( 4612 lambda component_spec: component_spec._batch(None), 4613 input_dataset.element_spec) 4614 4615 variant_tensor = gen_dataset_ops.parallel_batch_dataset( 4616 input_dataset._variant_tensor, 4617 batch_size=self._batch_size, 4618 num_parallel_calls=self._num_parallel_calls, 4619 drop_remainder=self._drop_remainder, 4620 deterministic=self._deterministic, 4621 **self._flat_structure) 4622 4623 super(ParallelBatchDataset, self).__init__(input_dataset, variant_tensor) 4624 4625 @property 4626 def element_spec(self): 4627 return self._structure 4628 4629 4630class _NumpyIterator(object): 4631 """Iterator over a dataset with elements converted to numpy.""" 4632 4633 __slots__ = ["_iterator"] 4634 4635 def __init__(self, dataset): 4636 self._iterator = iter(dataset) 4637 4638 def __iter__(self): 4639 return self 4640 4641 def __next__(self): 4642 4643 def to_numpy(x): 4644 numpy = x._numpy() # pylint: disable=protected-access 4645 if isinstance(numpy, np.ndarray): 4646 # `numpy` shares the same underlying buffer as the `x` Tensor. 4647 # Tensors are expected to be immutable, so we disable writes. 4648 numpy.setflags(write=False) 4649 return numpy 4650 4651 return nest.map_structure(to_numpy, next(self._iterator)) 4652 4653 def next(self): 4654 return self.__next__() 4655 4656 4657class _VariantTracker(tracking.CapturableResource): 4658 """Allows export of functions capturing a Dataset in SavedModels. 4659 4660 When saving a SavedModel, `tf.saved_model.save` traverses the object 4661 graph. Since Datasets reference _VariantTracker objects, that traversal will 4662 find a _VariantTracker for each Dataset and so know how to save and restore 4663 functions which reference the Dataset's variant Tensor. 4664 """ 4665 4666 def __init__(self, variant_tensor, resource_creator): 4667 """Record that `variant_tensor` is associated with `resource_creator`. 4668 4669 Args: 4670 variant_tensor: The variant-dtype Tensor associated with the Dataset. This 4671 Tensor will be a captured input to functions which use the Dataset, and 4672 is used by saving code to identify the corresponding _VariantTracker. 4673 resource_creator: A zero-argument function which creates a new 4674 variant-dtype Tensor. This function will be included in SavedModels and 4675 run to re-create the Dataset's variant Tensor on restore. 4676 """ 4677 super(_VariantTracker, self).__init__(device="CPU") 4678 self._resource_handle = variant_tensor 4679 self._create_resource = resource_creator 4680 4681 4682def _is_padded_shape_compatible_with(padded_shape, input_component_shape): 4683 """Returns `True` if `input_component_shape` can be padded to `padded_shape`. 4684 4685 Args: 4686 padded_shape: A `tf.TensorShape`. 4687 input_component_shape: A `tf.TensorShape`. 4688 4689 Returns: 4690 `True` if `input_component_shape` can be padded to `padded_shape`, otherwise 4691 `False`. 
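
  For illustration, a hedged sketch of the intended semantics (this is a
  private helper, not public API; it uses this module's `tensor_shape` import):

  ```python
  _is_padded_shape_compatible_with(tensor_shape.TensorShape([None, 4]),
                                   tensor_shape.TensorShape([3, 4]))  # True
  _is_padded_shape_compatible_with(tensor_shape.TensorShape([2, 4]),
                                   tensor_shape.TensorShape([3, 4]))  # False
  ```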
4692 """ 4693 4694 if padded_shape.dims is None or input_component_shape.dims is None: 4695 return True 4696 if len(padded_shape.dims) != len(input_component_shape.dims): 4697 return False 4698 for padded_dim, input_dim in zip( 4699 padded_shape.dims, input_component_shape.dims): 4700 if (padded_dim.value is not None and input_dim.value is not None 4701 and padded_dim.value < input_dim.value): 4702 return False 4703 return True 4704 4705 4706def _padded_shape_to_tensor(padded_shape, input_component_shape): 4707 """Converts `padded_shape` to a `tf.Tensor` representing that shape. 4708 4709 Args: 4710 padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python 4711 sequence, or a 1-D `tf.Tensor` of `tf.int64` elements. 4712 input_component_shape: A `tf.TensorShape`, with which `padded_shape` must 4713 be compatible. 4714 4715 Returns: 4716 A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`. 4717 4718 Raises: 4719 ValueError: If `padded_shape` is not a shape or not compatible with 4720 `input_component_shape`. 4721 TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor. 4722 """ 4723 try: 4724 # Try to convert the `padded_shape` to a `tf.TensorShape` 4725 padded_shape_as_shape = tensor_shape.as_shape(padded_shape) 4726 # We will return the "canonical" tensor representation, which uses 4727 # `-1` in place of `None`. 4728 ret = ops.convert_to_tensor( 4729 [dim if dim is not None else -1 4730 for dim in padded_shape_as_shape.as_list()], dtype=dtypes.int64) 4731 except (TypeError, ValueError): 4732 # The argument was not trivially convertible to a 4733 # `tf.TensorShape`, so fall back on the conversion to tensor 4734 # machinery. 4735 ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64) 4736 if ret.shape.dims is not None and len(ret.shape.dims) != 1: 4737 six.reraise(ValueError, ValueError( 4738 "Padded shape %s must be a 1-D tensor of tf.int64 values, but its " 4739 "shape was %s." % (padded_shape, ret.shape)), sys.exc_info()[2]) 4740 if ret.dtype != dtypes.int64: 4741 six.reraise( 4742 TypeError, 4743 TypeError( 4744 "Padded shape %s must be a 1-D tensor of tf.int64 values, but " 4745 "its element type was %s." % (padded_shape, ret.dtype.name)), 4746 sys.exc_info()[2]) 4747 padded_shape_as_shape = tensor_util.constant_value_as_shape(ret) 4748 4749 if not _is_padded_shape_compatible_with(padded_shape_as_shape, 4750 input_component_shape): 4751 raise ValueError("The padded shape %s is not compatible with the " 4752 "corresponding input component shape %s." 4753 % (padded_shape_as_shape, input_component_shape)) 4754 4755 return ret 4756 4757 4758def _padding_value_to_tensor(value, output_type): 4759 """Converts the padding value to a tensor. 4760 4761 Args: 4762 value: The padding value. 4763 output_type: Its expected dtype. 4764 4765 Returns: 4766 A scalar `Tensor`. 4767 4768 Raises: 4769 ValueError: if the padding value is not a scalar. 4770 TypeError: if the padding value's type does not match `output_type`. 
4771 """ 4772 value = ops.convert_to_tensor(value, name="padding_value") 4773 if not value.shape.is_compatible_with(tensor_shape.TensorShape([])): 4774 raise ValueError("Padding value should be a scalar, but is not: %s" % value) 4775 if value.dtype != output_type: 4776 raise TypeError("Padding value tensor (%s) does not match output type: %s" % 4777 (value, output_type)) 4778 return value 4779 4780 4781def _padding_values_or_default(padding_values, input_dataset): 4782 """Returns padding values with None elements replaced with default values.""" 4783 4784 def make_zero(t): 4785 if t.base_dtype == dtypes.string: 4786 return "" 4787 elif t.base_dtype == dtypes.variant: 4788 error_msg = ("Unable to create padding for field of type 'variant' " 4789 "because t.base_type == dtypes.variant == " 4790 "{}.".format(t.base_dtype)) 4791 raise TypeError(error_msg) 4792 elif t.base_dtype == dtypes.bfloat16: 4793 # Special case `bfloat16` because it is not supported by NumPy. 4794 return constant_op.constant(0, dtype=dtypes.bfloat16) 4795 else: 4796 return np.zeros_like(t.as_numpy_dtype()) 4797 4798 def value_or_default(value, default): 4799 return default if value is None else value 4800 4801 default_padding = nest.map_structure( 4802 make_zero, 4803 get_legacy_output_types(input_dataset)) 4804 return nest.map_structure_up_to(padding_values, value_or_default, 4805 padding_values, default_padding) 4806 4807 4808class PaddedBatchDataset(UnaryDataset): 4809 """A `Dataset` that batches and pads contiguous elements from its input.""" 4810 4811 def __init__(self, input_dataset, batch_size, padded_shapes, padding_values, 4812 drop_remainder): 4813 """See `Dataset.batch()` for details.""" 4814 self._input_dataset = input_dataset 4815 4816 def check_types(component_spec): 4817 if not isinstance(component_spec, tensor_spec.TensorSpec): 4818 raise TypeError("Padded batching of components of type ", 4819 type(component_spec), " is not supported.") 4820 4821 nest.map_structure(check_types, input_dataset.element_spec) 4822 self._input_dataset = input_dataset 4823 self._batch_size = ops.convert_to_tensor( 4824 batch_size, dtype=dtypes.int64, name="batch_size") 4825 padding_values = _padding_values_or_default(padding_values, input_dataset) 4826 4827 input_shapes = get_legacy_output_shapes(input_dataset) 4828 flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes) 4829 4830 flat_padded_shapes_as_tensors = [] 4831 4832 for input_component_shape, padded_shape in zip( 4833 nest.flatten(input_shapes), flat_padded_shapes): 4834 flat_padded_shapes_as_tensors.append( 4835 _padded_shape_to_tensor(padded_shape, input_component_shape)) 4836 4837 self._padded_shapes = nest.pack_sequence_as(input_shapes, 4838 flat_padded_shapes_as_tensors) 4839 4840 # If padding_values is a single element and input_shapes is a structure, 4841 # "broadcast" padding_values to the same structure as input_shapes. 
4842 if nest.is_sequence(input_shapes) and not nest.is_sequence(padding_values): 4843 padding_values = nest.map_structure(lambda _: padding_values, 4844 input_shapes) 4845 4846 self._padding_values = nest.map_structure_up_to( 4847 input_shapes, _padding_value_to_tensor, padding_values, 4848 get_legacy_output_types(input_dataset)) 4849 self._drop_remainder = ops.convert_to_tensor( 4850 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 4851 4852 def _padded_shape_to_batch_shape(s): 4853 return tensor_shape.TensorShape([ 4854 tensor_util.constant_value(self._batch_size) 4855 if smart_cond.smart_constant_value(self._drop_remainder) else None 4856 ]).concatenate(tensor_util.constant_value_as_shape(s)) 4857 4858 output_shapes = nest.map_structure( 4859 _padded_shape_to_batch_shape, self._padded_shapes) 4860 self._structure = structure.convert_legacy_structure( 4861 get_legacy_output_types(self._input_dataset), output_shapes, 4862 get_legacy_output_classes(self._input_dataset)) 4863 4864 # pylint: disable=protected-access 4865 # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018. 4866 if smart_cond.smart_constant_value(self._drop_remainder) is False: 4867 variant_tensor = gen_dataset_ops.padded_batch_dataset( 4868 input_dataset._variant_tensor, # pylint: disable=protected-access 4869 batch_size=self._batch_size, 4870 padded_shapes=[ 4871 ops.convert_to_tensor(s, dtype=dtypes.int64) 4872 for s in nest.flatten(self._padded_shapes) 4873 ], 4874 padding_values=nest.flatten(self._padding_values), 4875 output_shapes=structure.get_flat_tensor_shapes(self._structure)) 4876 else: 4877 variant_tensor = gen_dataset_ops.padded_batch_dataset_v2( 4878 input_dataset._variant_tensor, # pylint: disable=protected-access 4879 batch_size=self._batch_size, 4880 padded_shapes=[ 4881 ops.convert_to_tensor(s, dtype=dtypes.int64) 4882 for s in nest.flatten(self._padded_shapes) 4883 ], 4884 padding_values=nest.flatten(self._padding_values), 4885 drop_remainder=self._drop_remainder, 4886 output_shapes=structure.get_flat_tensor_shapes(self._structure)) 4887 super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor) 4888 4889 @property 4890 def element_spec(self): 4891 return self._structure 4892 4893 4894def _should_pack(arg): 4895 """Determines whether the caller needs to pack the argument in a tuple. 4896 4897 If user-defined function returns a list of tensors, `nest.flatten()` and 4898 `ops.convert_to_tensor()` and would conspire to attempt to stack those tensors 4899 into a single tensor because the tf.data version of `nest.flatten()` does 4900 not recurse into lists. Since it is more likely that the list arose from 4901 returning the result of an operation (such as `tf.numpy_function()`) that 4902 returns a list of not-necessarily-stackable tensors, we treat the returned 4903 value as a `tuple` instead. A user wishing to pack the return value into a 4904 single tensor can use an explicit `tf.stack()` before returning. 4905 4906 Args: 4907 arg: argument to check 4908 4909 Returns: 4910 Indication of whether the caller needs to pack the argument in a tuple. 4911 """ 4912 return isinstance(arg, list) 4913 4914 4915def _should_unpack(arg): 4916 """Determines whether the caller needs to unpack the argument from a tuple. 4917 4918 Args: 4919 arg: argument to check 4920 4921 Returns: 4922 Indication of whether the caller needs to unpack the argument from a tuple. 
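
  For illustration, a hedged sketch of the check below:

  ```python
  _should_unpack((1, 2))    # True: plain tuples are expanded into arguments
  _should_unpack([1, 2])    # False: lists are passed through as a single value
  _should_unpack({"a": 1})  # False
  ```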
4923 """ 4924 return type(arg) is tuple # pylint: disable=unidiomatic-typecheck 4925 4926 4927class MapDataset(UnaryDataset): 4928 """A `Dataset` that maps a function over elements in its input.""" 4929 4930 def __init__(self, 4931 input_dataset, 4932 map_func, 4933 use_inter_op_parallelism=True, 4934 preserve_cardinality=False, 4935 use_legacy_function=False): 4936 """See `Dataset.map()` for details.""" 4937 self._input_dataset = input_dataset 4938 self._use_inter_op_parallelism = use_inter_op_parallelism 4939 self._preserve_cardinality = preserve_cardinality 4940 self._map_func = StructuredFunctionWrapper( 4941 map_func, 4942 self._transformation_name(), 4943 dataset=input_dataset, 4944 use_legacy_function=use_legacy_function) 4945 variant_tensor = gen_dataset_ops.map_dataset( 4946 input_dataset._variant_tensor, # pylint: disable=protected-access 4947 self._map_func.function.captured_inputs, 4948 f=self._map_func.function, 4949 use_inter_op_parallelism=self._use_inter_op_parallelism, 4950 preserve_cardinality=self._preserve_cardinality, 4951 **self._flat_structure) 4952 super(MapDataset, self).__init__(input_dataset, variant_tensor) 4953 4954 def _functions(self): 4955 return [self._map_func] 4956 4957 @property 4958 def element_spec(self): 4959 return self._map_func.output_structure 4960 4961 def _transformation_name(self): 4962 return "Dataset.map()" 4963 4964 4965class ParallelMapDataset(UnaryDataset): 4966 """A `Dataset` that maps a function over elements in its input in parallel.""" 4967 4968 def __init__(self, 4969 input_dataset, 4970 map_func, 4971 num_parallel_calls, 4972 deterministic, 4973 use_inter_op_parallelism=True, 4974 preserve_cardinality=False, 4975 use_legacy_function=False): 4976 """See `Dataset.map()` for details.""" 4977 self._input_dataset = input_dataset 4978 self._use_inter_op_parallelism = use_inter_op_parallelism 4979 self._map_func = StructuredFunctionWrapper( 4980 map_func, 4981 self._transformation_name(), 4982 dataset=input_dataset, 4983 use_legacy_function=use_legacy_function) 4984 if deterministic is None: 4985 self._deterministic = "default" 4986 elif deterministic: 4987 self._deterministic = "true" 4988 else: 4989 self._deterministic = "false" 4990 self._preserve_cardinality = preserve_cardinality 4991 self._num_parallel_calls = ops.convert_to_tensor( 4992 num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") 4993 variant_tensor = gen_dataset_ops.parallel_map_dataset_v2( 4994 input_dataset._variant_tensor, # pylint: disable=protected-access 4995 self._map_func.function.captured_inputs, 4996 f=self._map_func.function, 4997 num_parallel_calls=self._num_parallel_calls, 4998 deterministic=self._deterministic, 4999 use_inter_op_parallelism=self._use_inter_op_parallelism, 5000 preserve_cardinality=self._preserve_cardinality, 5001 **self._flat_structure) 5002 super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor) 5003 5004 def _functions(self): 5005 return [self._map_func] 5006 5007 @property 5008 def element_spec(self): 5009 return self._map_func.output_structure 5010 5011 def _transformation_name(self): 5012 return "Dataset.map()" 5013 5014 5015class FlatMapDataset(UnaryDataset): 5016 """A `Dataset` that maps a function over its input and flattens the result.""" 5017 5018 def __init__(self, input_dataset, map_func): 5019 """See `Dataset.flat_map()` for details.""" 5020 self._input_dataset = input_dataset 5021 self._map_func = StructuredFunctionWrapper( 5022 map_func, self._transformation_name(), dataset=input_dataset) 5023 if 
not isinstance(self._map_func.output_structure, DatasetSpec): 5024 raise TypeError( 5025 "`map_func` must return a `Dataset` object. Got {}".format( 5026 type(self._map_func.output_structure))) 5027 self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access 5028 variant_tensor = gen_dataset_ops.flat_map_dataset( 5029 input_dataset._variant_tensor, # pylint: disable=protected-access 5030 self._map_func.function.captured_inputs, 5031 f=self._map_func.function, 5032 **self._flat_structure) 5033 super(FlatMapDataset, self).__init__(input_dataset, variant_tensor) 5034 5035 def _functions(self): 5036 return [self._map_func] 5037 5038 @property 5039 def element_spec(self): 5040 return self._structure 5041 5042 def _transformation_name(self): 5043 return "Dataset.flat_map()" 5044 5045 5046class InterleaveDataset(UnaryDataset): 5047 """A `Dataset` that interleaves the result of transformed inputs.""" 5048 5049 def __init__(self, input_dataset, map_func, cycle_length, block_length): 5050 """See `Dataset.interleave()` for details.""" 5051 5052 self._input_dataset = input_dataset 5053 self._map_func = StructuredFunctionWrapper( 5054 map_func, self._transformation_name(), dataset=input_dataset) 5055 if not isinstance(self._map_func.output_structure, DatasetSpec): 5056 raise TypeError( 5057 "`map_func` must return a `Dataset` object. Got {}".format( 5058 type(self._map_func.output_structure))) 5059 self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access 5060 self._cycle_length = ops.convert_to_tensor( 5061 cycle_length, dtype=dtypes.int64, name="cycle_length") 5062 self._block_length = ops.convert_to_tensor( 5063 block_length, dtype=dtypes.int64, name="block_length") 5064 5065 variant_tensor = gen_dataset_ops.interleave_dataset( 5066 input_dataset._variant_tensor, # pylint: disable=protected-access 5067 self._map_func.function.captured_inputs, # pylint: disable=protected-access 5068 self._cycle_length, 5069 self._block_length, 5070 f=self._map_func.function, 5071 **self._flat_structure) 5072 super(InterleaveDataset, self).__init__(input_dataset, variant_tensor) 5073 5074 def _functions(self): 5075 return [self._map_func] 5076 5077 @property 5078 def element_spec(self): 5079 return self._structure 5080 5081 def _transformation_name(self): 5082 return "Dataset.interleave()" 5083 5084 5085class ParallelInterleaveDataset(UnaryDataset): 5086 """A `Dataset` that maps a function over its input and interleaves the result.""" 5087 5088 def __init__(self, 5089 input_dataset, 5090 map_func, 5091 cycle_length, 5092 block_length, 5093 num_parallel_calls, 5094 buffer_output_elements=AUTOTUNE, 5095 prefetch_input_elements=AUTOTUNE, 5096 deterministic=None): 5097 """See `Dataset.interleave()` for details.""" 5098 self._input_dataset = input_dataset 5099 self._map_func = StructuredFunctionWrapper( 5100 map_func, self._transformation_name(), dataset=input_dataset) 5101 if not isinstance(self._map_func.output_structure, DatasetSpec): 5102 raise TypeError( 5103 "`map_func` must return a `Dataset` object. 
Got {}".format( 5104 type(self._map_func.output_structure))) 5105 self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access 5106 self._cycle_length = ops.convert_to_tensor( 5107 cycle_length, dtype=dtypes.int64, name="cycle_length") 5108 self._block_length = ops.convert_to_tensor( 5109 block_length, dtype=dtypes.int64, name="block_length") 5110 self._buffer_output_elements = ops.convert_to_tensor( 5111 buffer_output_elements, 5112 dtype=dtypes.int64, 5113 name="buffer_output_elements") 5114 self._prefetch_input_elements = ops.convert_to_tensor( 5115 prefetch_input_elements, 5116 dtype=dtypes.int64, 5117 name="prefetch_input_elements") 5118 5119 self._num_parallel_calls = ops.convert_to_tensor( 5120 num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") 5121 if deterministic is None: 5122 deterministic_string = "default" 5123 elif deterministic: 5124 deterministic_string = "true" 5125 else: 5126 deterministic_string = "false" 5127 5128 variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4( 5129 input_dataset._variant_tensor, # pylint: disable=protected-access 5130 self._map_func.function.captured_inputs, # pylint: disable=protected-access 5131 self._cycle_length, 5132 self._block_length, 5133 self._buffer_output_elements, 5134 self._prefetch_input_elements, 5135 self._num_parallel_calls, 5136 f=self._map_func.function, 5137 deterministic=deterministic_string, 5138 **self._flat_structure) 5139 super(ParallelInterleaveDataset, self).__init__(input_dataset, 5140 variant_tensor) 5141 5142 def _functions(self): 5143 return [self._map_func] 5144 5145 @property 5146 def element_spec(self): 5147 return self._structure 5148 5149 def _transformation_name(self): 5150 return "Dataset.interleave()" 5151 5152 5153class FilterDataset(UnaryUnchangedStructureDataset): 5154 """A `Dataset` that filters its input according to a predicate function.""" 5155 5156 def __init__(self, input_dataset, predicate, use_legacy_function=False): 5157 """See `Dataset.filter()` for details.""" 5158 self._input_dataset = input_dataset 5159 wrapped_func = StructuredFunctionWrapper( 5160 predicate, 5161 self._transformation_name(), 5162 dataset=input_dataset, 5163 use_legacy_function=use_legacy_function) 5164 if not wrapped_func.output_structure.is_compatible_with( 5165 tensor_spec.TensorSpec([], dtypes.bool)): 5166 error_msg = ("`predicate` return type must be convertible to a scalar " 5167 "boolean tensor. Was {}.").format( 5168 wrapped_func.output_structure) 5169 raise ValueError(error_msg) 5170 self._predicate = wrapped_func 5171 variant_tensor = gen_dataset_ops.filter_dataset( 5172 input_dataset._variant_tensor, # pylint: disable=protected-access 5173 other_arguments=self._predicate.function.captured_inputs, 5174 predicate=self._predicate.function, 5175 **self._flat_structure) 5176 super(FilterDataset, self).__init__(input_dataset, variant_tensor) 5177 5178 def _functions(self): 5179 return [self._predicate] 5180 5181 def _transformation_name(self): 5182 return "Dataset.filter()" 5183 5184 5185class PrefetchDataset(UnaryUnchangedStructureDataset): 5186 """A `Dataset` that asynchronously prefetches its input.""" 5187 5188 def __init__(self, input_dataset, buffer_size, slack_period=None): 5189 """See `Dataset.prefetch()` for details. 5190 5191 Args: 5192 input_dataset: The input dataset. 5193 buffer_size: See `Dataset.prefetch()` for details. 5194 slack_period: (Optional.) An integer. 
If non-zero, determines the number 5195 of GetNext calls before injecting slack into the execution. This may 5196 reduce CPU contention at the start of a step. Note that a tensorflow 5197 user should not have to set this manually; enable this behavior 5198 automatically via `tf.data.Options.experimental_slack` instead. Defaults 5199 to None. 5200 """ 5201 self._input_dataset = input_dataset 5202 if buffer_size is None: 5203 buffer_size = AUTOTUNE 5204 self._buffer_size = ops.convert_to_tensor( 5205 buffer_size, dtype=dtypes.int64, name="buffer_size") 5206 # pylint: disable=protected-access 5207 # We colocate the prefetch dataset with its input as this collocation only 5208 # happens automatically in graph mode. 5209 with ops.colocate_with(input_dataset._variant_tensor): 5210 variant_tensor = gen_dataset_ops.prefetch_dataset( 5211 input_dataset._variant_tensor, 5212 buffer_size=self._buffer_size, 5213 slack_period=slack_period, 5214 **self._flat_structure) 5215 super(PrefetchDataset, self).__init__(input_dataset, variant_tensor) 5216 5217 5218class WindowDataset(UnaryDataset): 5219 """A dataset that creates window datasets from the input elements.""" 5220 5221 def __init__(self, input_dataset, size, shift, stride, drop_remainder): 5222 """See `window_dataset()` for more details.""" 5223 self._input_dataset = input_dataset 5224 self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size") 5225 self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift") 5226 self._stride = ops.convert_to_tensor( 5227 stride, dtype=dtypes.int64, name="stride") 5228 self._drop_remainder = ops.convert_to_tensor( 5229 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 5230 self._structure = nest.pack_sequence_as( 5231 get_legacy_output_classes(input_dataset), [ 5232 DatasetSpec( # pylint: disable=g-complex-comprehension 5233 structure.convert_legacy_structure( 5234 output_type, output_shape, output_class)) 5235 for output_class, output_shape, output_type in zip( 5236 nest.flatten(get_legacy_output_classes(input_dataset)), 5237 nest.flatten(get_legacy_output_shapes(input_dataset)), 5238 nest.flatten(get_legacy_output_types(input_dataset))) 5239 ]) 5240 variant_tensor = gen_dataset_ops.window_dataset( 5241 input_dataset._variant_tensor, # pylint: disable=protected-access 5242 self._size, 5243 self._shift, 5244 self._stride, 5245 self._drop_remainder, 5246 **self._flat_structure) 5247 super(WindowDataset, self).__init__(input_dataset, variant_tensor) 5248 5249 @property 5250 def element_spec(self): 5251 return self._structure 5252 5253 5254class _OptionsDataset(UnaryUnchangedStructureDataset): 5255 """An identity `Dataset` that stores options.""" 5256 5257 def __init__(self, input_dataset, options): 5258 # pylint: disable=protected-access 5259 self._input_dataset = input_dataset 5260 options_pb = dataset_options_pb2.Options() 5261 options_pb.CopyFrom(options._to_proto()) 5262 with ops.colocate_with(input_dataset._variant_tensor): 5263 variant_tensor = gen_dataset_ops.options_dataset( 5264 input_dataset._variant_tensor, 5265 options_pb.SerializeToString(), **self._flat_structure) 5266 super(_OptionsDataset, self).__init__(input_dataset, variant_tensor) 5267 5268 if self._options_attr: 5269 self._options_attr._set_mutable(True) 5270 self._options_attr = self._options_attr.merge(options) 5271 else: 5272 self._options_attr = options 5273 self._options_attr._set_mutable(False) 5274 5275 5276def normalize_to_dense(dataset): 5277 """Normalizes non-tensor components in a dataset to dense 
representations. 5278 5279 This is necessary for dataset transformations that slice along the batch 5280 dimension and are oblivious to non-tensors, e.g. `unbatch`, `rebatch`. 5281 5282 Args: 5283 dataset: Dataset to normalize. 5284 5285 Returns: 5286 A dataset whose sparse and ragged tensors have been normalized to their 5287 dense representations. 5288 """ 5289 5290 # NOTE(mrry): This leads to a somewhat inefficient re-encoding step for all 5291 # non-tensor components. 5292 # 5293 # TODO(mrry): Consider optimizing this if it turns out to be a bottleneck. 5294 if _should_unpack(dataset.element_spec): 5295 5296 def normalize(*args): 5297 return structure.to_batched_tensor_list(dataset.element_spec, tuple(args)) 5298 else: 5299 def normalize(arg): 5300 return structure.to_batched_tensor_list(dataset.element_spec, arg) 5301 5302 normalized_dataset = dataset.map(normalize) 5303 5304 # NOTE(mrry): Our `map()` has lost information about the structure of 5305 # non-tensor components, so re-apply the structure of the original dataset. 5306 return _RestructuredDataset(normalized_dataset, dataset.element_spec) 5307 5308 5309class _RestructuredDataset(UnaryDataset): 5310 """An internal helper for changing the element spec of a dataset.""" 5311 5312 def __init__(self, dataset, structure): 5313 self._input_dataset = dataset 5314 self._structure = structure 5315 5316 variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access 5317 super(_RestructuredDataset, self).__init__(dataset, variant_tensor) 5318 5319 @property 5320 def element_spec(self): 5321 return self._structure 5322 5323 5324class _UnbatchDataset(UnaryDataset): 5325 """A dataset that splits the elements of its input into multiple elements.""" 5326 5327 def __init__(self, input_dataset): 5328 """See `unbatch()` for more details.""" 5329 flat_shapes = input_dataset._flat_shapes # pylint: disable=protected-access 5330 if any(s.ndims == 0 for s in flat_shapes): 5331 raise ValueError("Cannot unbatch an input with scalar components.") 5332 known_batch_dim = tensor_shape.Dimension(None) 5333 for s in flat_shapes: 5334 try: 5335 known_batch_dim = known_batch_dim.merge_with(s[0]) 5336 except ValueError: 5337 raise ValueError("Cannot unbatch an input whose components have " 5338 "different batch sizes.") 5339 self._input_dataset = input_dataset 5340 self._structure = nest.map_structure( 5341 lambda component_spec: component_spec._unbatch(), # pylint: disable=protected-access 5342 get_structure(input_dataset)) 5343 variant_tensor = ged_ops.unbatch_dataset( 5344 self._input_dataset._variant_tensor, # pylint: disable=protected-access 5345 **self._flat_structure) 5346 super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor) 5347 5348 @property 5349 def element_spec(self): 5350 return self._structure 5351 5352 5353class _GroupByWindowDataset(UnaryDataset): 5354 """A `Dataset` that groups its input and performs a windowed reduction.""" 5355 5356 def __init__(self, input_dataset, key_func, reduce_func, window_size_func): 5357 """See `group_by_window()` for details.""" 5358 self._input_dataset = input_dataset 5359 self._make_key_func(key_func, input_dataset) 5360 self._make_reduce_func(reduce_func, input_dataset) 5361 self._make_window_size_func(window_size_func) 5362 variant_tensor = ged_ops.group_by_window_dataset( 5363 self._input_dataset._variant_tensor, # pylint: disable=protected-access 5364 self._key_func.function.captured_inputs, 5365 self._reduce_func.function.captured_inputs, 5366 
self._window_size_func.function.captured_inputs, 5367 key_func=self._key_func.function, 5368 reduce_func=self._reduce_func.function, 5369 window_size_func=self._window_size_func.function, 5370 **self._flat_structure) 5371 super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor) 5372 5373 def _make_window_size_func(self, window_size_func): 5374 """Make wrapping defun for window_size_func.""" 5375 5376 def window_size_func_wrapper(key): 5377 return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64) 5378 5379 self._window_size_func = StructuredFunctionWrapper( 5380 window_size_func_wrapper, 5381 self._transformation_name(), 5382 input_structure=tensor_spec.TensorSpec([], dtypes.int64)) 5383 if not self._window_size_func.output_structure.is_compatible_with( 5384 tensor_spec.TensorSpec([], dtypes.int64)): 5385 raise ValueError( 5386 "`window_size_func` must return a single tf.int64 scalar tensor.") 5387 5388 def _make_key_func(self, key_func, input_dataset): 5389 """Make wrapping defun for key_func.""" 5390 5391 def key_func_wrapper(*args): 5392 return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64) 5393 5394 self._key_func = StructuredFunctionWrapper( 5395 key_func_wrapper, self._transformation_name(), dataset=input_dataset) 5396 if not self._key_func.output_structure.is_compatible_with( 5397 tensor_spec.TensorSpec([], dtypes.int64)): 5398 raise ValueError( 5399 "`key_func` must return a single tf.int64 scalar tensor.") 5400 5401 def _make_reduce_func(self, reduce_func, input_dataset): 5402 """Make wrapping defun for reduce_func.""" 5403 nested_dataset = DatasetSpec(input_dataset.element_spec) 5404 input_structure = (tensor_spec.TensorSpec([], dtypes.int64), nested_dataset) 5405 self._reduce_func = StructuredFunctionWrapper( 5406 reduce_func, 5407 self._transformation_name(), 5408 input_structure=input_structure) 5409 if not isinstance(self._reduce_func.output_structure, DatasetSpec): 5410 raise TypeError("`reduce_func` must return a `Dataset` object.") 5411 # pylint: disable=protected-access 5412 self._element_spec = (self._reduce_func.output_structure._element_spec) 5413 5414 @property 5415 def element_spec(self): 5416 return self._element_spec 5417 5418 def _functions(self): 5419 return [self._key_func, self._reduce_func, self._window_size_func] 5420 5421 def _transformation_name(self): 5422 return "Dataset.group_by_window()" 5423 5424 5425class RandomDataset(DatasetSource): 5426 """A `Dataset` of pseudorandom values.""" 5427 5428 def __init__(self, seed=None): 5429 """A `Dataset` of pseudorandom values.""" 5430 self._seed, self._seed2 = random_seed.get_seed(seed) 5431 variant_tensor = ged_ops.random_dataset( 5432 seed=self._seed, seed2=self._seed2, **self._flat_structure) 5433 super(RandomDataset, self).__init__(variant_tensor) 5434 5435 @property 5436 def element_spec(self): 5437 return tensor_spec.TensorSpec([], dtypes.int64) 5438 5439 5440def _get_prob_original_static(initial_dist_t, target_dist_t): 5441 """Returns the static probability of sampling from the original. 5442 5443 `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters 5444 an Op that it isn't defined for. We have some custom logic to avoid this. 5445 5446 Args: 5447 initial_dist_t: A tensor of the initial distribution. 5448 target_dist_t: A tensor of the target distribution. 5449 5450 Returns: 5451 The probability of sampling from the original distribution as a constant, 5452 if it is a constant, or `None`. 
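
  For illustration (hypothetical values): if `initial_dist_t` is the constant
  `[0.5, 0.5]` and `target_dist_t` is the constant `[0.2, 0.8]`, the result is
  `min(0.2 / 0.5, 0.8 / 0.5) = 0.4`; if either tensor is not statically known,
  `None` is returned instead.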
5453 """ 5454 init_static = tensor_util.constant_value(initial_dist_t) 5455 target_static = tensor_util.constant_value(target_dist_t) 5456 5457 if init_static is None or target_static is None: 5458 return None 5459 else: 5460 return np.min(target_static / init_static) 5461 5462 5463def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_func, seed): 5464 """Filters a dataset based on per-class acceptance probabilities. 5465 5466 Args: 5467 dataset: The dataset to be filtered. 5468 acceptance_dist_ds: A dataset of acceptance probabilities. 5469 initial_dist_ds: A dataset of the initial probability distribution, given or 5470 estimated. 5471 class_func: A function mapping an element of the input dataset to a scalar 5472 `tf.int32` tensor. Values should be in `[0, num_classes)`. 5473 seed: (Optional.) Python integer seed for the resampler. 5474 5475 Returns: 5476 A dataset of (class value, data) after filtering. 5477 """ 5478 5479 def maybe_warn_on_large_rejection(accept_dist, initial_dist): 5480 proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist) 5481 return control_flow_ops.cond( 5482 math_ops.less(proportion_rejected, .5), 5483 lambda: accept_dist, 5484 lambda: logging_ops.Print( # pylint: disable=g-long-lambda 5485 accept_dist, [proportion_rejected, initial_dist, accept_dist], 5486 message="Proportion of examples rejected by sampler is high: ", 5487 summarize=100, 5488 first_n=10)) 5489 5490 acceptance_dist_ds = ( 5491 DatasetV2.zip((acceptance_dist_ds, 5492 initial_dist_ds)).map(maybe_warn_on_large_rejection)) 5493 5494 def _gather_and_copy(acceptance_prob, data): 5495 if isinstance(data, tuple): 5496 class_val = class_func(*data) 5497 else: 5498 class_val = class_func(data) 5499 return class_val, array_ops.gather(acceptance_prob, class_val), data 5500 5501 current_probabilities_and_class_and_data_ds = DatasetV2.zip( 5502 (acceptance_dist_ds, dataset)).map(_gather_and_copy) 5503 5504 def _reject(unused_class_val, p, unused_data): 5505 return random_ops.random_uniform([], seed=seed, dtype=p.dtype) < p 5506 5507 filtered_ds = current_probabilities_and_class_and_data_ds.filter(_reject) 5508 return filtered_ds.map(lambda class_value, _, data: (class_value, data)) 5509 5510 5511# pylint: disable=missing-function-docstring 5512def _estimate_initial_dist_ds(target_dist_t, 5513 class_values_ds, 5514 dist_estimation_batch_size=32, 5515 smoothing_constant=10): 5516 num_classes = (target_dist_t.shape[0] or array_ops.shape(target_dist_t)[0]) 5517 initial_examples_per_class_seen = array_ops.fill([num_classes], 5518 np.int64(smoothing_constant)) 5519 5520 def update_estimate_and_tile(num_examples_per_class_seen, c): 5521 updated_examples_per_class_seen, dist = _estimate_data_distribution( 5522 c, num_examples_per_class_seen) 5523 tiled_dist = array_ops.tile( 5524 array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) 5525 return updated_examples_per_class_seen, tiled_dist 5526 5527 initial_dist_ds = ( 5528 class_values_ds.batch(dist_estimation_batch_size).scan( 5529 initial_examples_per_class_seen, update_estimate_and_tile).unbatch()) 5530 5531 return initial_dist_ds 5532 5533 5534def _get_target_to_initial_ratio(initial_probs, target_probs): 5535 # Add tiny to initial_probs to avoid divide by zero. 5536 denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) 5537 return target_probs / denom 5538 5539 5540def _estimate_data_distribution(c, num_examples_per_class_seen): 5541 """Estimate data distribution as labels are seen. 

  Args:
    c: The class labels. Type `int32`, shape `[batch_size]`.
    num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, containing
      counts.

  Returns:
    num_examples_per_class_seen: Updated counts. Type `int64`, shape
      `[num_classes]`.
    dist: The updated distribution. Type `float32`, shape `[num_classes]`.
  """
  num_classes = num_examples_per_class_seen.get_shape()[0]
  # Update the class-count based on what labels are seen in batch.
  num_examples_per_class_seen = math_ops.add(
      num_examples_per_class_seen,
      math_ops.reduce_sum(
          array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
  init_prob_estimate = math_ops.truediv(
      num_examples_per_class_seen,
      math_ops.reduce_sum(num_examples_per_class_seen))
  dist = math_ops.cast(init_prob_estimate, dtypes.float32)
  return num_examples_per_class_seen, dist


def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
  """Calculates the acceptance probabilities and mixing ratio.

  In this case, we assume that we can *either* sample from the original data
  distribution with probability `m`, or sample from a reshaped distribution
  that comes from rejection sampling on the original distribution. This
  rejection sampling is done on a per-class basis, with `a_i` representing the
  probability of accepting data from class `i`.

  This method is based on solving the following analysis for the reshaped
  distribution:

  Let F be the probability of a rejection (on any example).
  Let p_i be the proportion of examples in the data in class i (init_probs).
  Let a_i be the rate at which the rejection sampler should *accept* class i.
  Let t_i be the target proportion in the minibatches for class i
  (target_probs).

  ```
  F = sum_i(p_i * (1-a_i))
    = 1 - sum_i(p_i * a_i)     using sum_i(p_i) = 1
  ```

  An example of class `i` is ultimately accepted if some number `k` of
  rejections occur first, and then an example of class `i` is seen by the
  rejector and accepted. Summing over `k`, this can be written as follows:

  ```
  t_i = sum_k=0^inf(F^k * p_i * a_i)
      = p_i * a_i / (1 - F)           using geometric series identity, since 0 <= F < 1
      = p_i * a_i / sum_j(p_j * a_j)  using F from above
  ```

  Note that the following constraints hold:
  ```
  0 <= p_i <= 1, sum_i(p_i) = 1
  0 <= a_i <= 1
  0 <= t_i <= 1, sum_i(t_i) = 1
  ```

  A solution for a_i in terms of the other variables is the following:
  ```a_i = (t_i / p_i) / max_i[t_i / p_i]```

  If we try to minimize the amount of data rejected, we get the following:

  M_max = max_i [ t_i / p_i ]
  M_min = min_i [ t_i / p_i ]

  The desired probability of accepting data if it comes from class `i`:

  a_i = (t_i/p_i - m) / (M_max - m)

  The desired probability of pulling a data element from the original dataset,
  rather than the filtered one:

  m = M_min

  Args:
    initial_probs: A Tensor of the initial probability distribution, given or
      estimated.
    target_probs: A Tensor of the target probability distribution over classes.

  Returns:
    (A 1D Tensor with the per-class acceptance probabilities, the desired
    probability of pulling from the original distribution.)
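
  For illustration (hypothetical values): with `initial_probs = [0.7, 0.3]`
  and `target_probs = [0.5, 0.5]`, the ratios `t_i / p_i` are approximately
  `[0.71, 1.67]`, so `m = M_min ~= 0.71` and
  `a_i = (t_i / p_i - m) / (M_max - m) ~= [0.0, 1.0]`: class 1 is always
  accepted by the rejection sampler, while class 0 examples only enter batches
  via the "sample from the original distribution" path taken with probability
  `m`.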
5630 """ 5631 ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs) 5632 max_ratio = math_ops.reduce_max(ratio_l) 5633 min_ratio = math_ops.reduce_min(ratio_l) 5634 5635 # Target prob to sample from original distribution. 5636 m = min_ratio 5637 5638 # TODO(joelshor): Simplify fraction, if possible. 5639 a_i = (ratio_l - m) / (max_ratio - m) 5640 return a_i, m 5641 5642 5643class _TakeWhileDataset(UnaryUnchangedStructureDataset): 5644 """A dataset that stops iteration when `predicate` returns false.""" 5645 5646 def __init__(self, input_dataset, predicate): 5647 """See `take_while()` for details.""" 5648 5649 self._input_dataset = input_dataset 5650 wrapped_func = StructuredFunctionWrapper( 5651 predicate, self._transformation_name(), dataset=self._input_dataset) 5652 5653 if not wrapped_func.output_structure.is_compatible_with( 5654 tensor_spec.TensorSpec([], dtypes.bool)): 5655 raise ValueError("`predicate` must return a scalar boolean tensor.") 5656 5657 self._predicate = wrapped_func 5658 var_tensor = ged_ops.take_while_dataset( 5659 self._input_dataset._variant_tensor, # pylint: disable=protected-access 5660 other_arguments=self._predicate.function.captured_inputs, 5661 predicate=self._predicate.function, 5662 **self._flat_structure) 5663 super(_TakeWhileDataset, self).__init__(input_dataset, var_tensor) 5664 5665 def _functions(self): 5666 return [self._predicate] 5667 5668 def _transformation_name(self): 5669 return "Dataset.take_while()" 5670 5671 5672class _UniqueDataset(UnaryUnchangedStructureDataset): 5673 """A `Dataset` contains the unique elements from its input.""" 5674 5675 def __init__(self, input_dataset): 5676 """See `unique()` for details.""" 5677 self._input_dataset = input_dataset 5678 if get_legacy_output_types(input_dataset) not in (dtypes.int32, 5679 dtypes.int64, 5680 dtypes.string): 5681 raise TypeError( 5682 "`tf.data.Dataset.unique()` only supports inputs with a single " 5683 "`tf.int32`, `tf.int64`, or `tf.string` component.") 5684 variant_tensor = ged_ops.unique_dataset( 5685 self._input_dataset._variant_tensor, # pylint: disable=protected-access 5686 **self._flat_structure) 5687 super(_UniqueDataset, self).__init__(input_dataset, variant_tensor) 5688 5689 5690def _collect_resource_inputs(op): 5691 """Collects resource inputs for the given ops (and its variant inputs).""" 5692 5693 def _process(op_queue, seen_ops): 5694 """Processes the next element of the op queue. 5695 5696 Args: 5697 op_queue: Queue of Dataset operations to process. 5698 seen_ops: Already processed set of Operations. 5699 5700 Returns: 5701 A 2-tuple containing sets of resource handles. The first tuple entry 5702 contains read-only handles and the second entry contains read-write 5703 handles. 5704 """ 5705 5706 reads = [] 5707 writes = [] 5708 op = op_queue.pop() 5709 if op in seen_ops: 5710 return reads, writes 5711 seen_ops.add(op) 5712 # TODO(b/150139257): All resource inputs are in writes right now since we 5713 # have not updated the functional ops to set the special attribute that ACD 5714 # uses to figure out which of the op's inputs are read-only. 5715 reads, writes = acd_utils.get_read_write_resource_inputs(op) 5716 # Conservatively assume that any variant inputs are datasets. 
5717 op_queue.extend(t.op for t in op.inputs if t.dtype == dtypes.variant) 5718 return reads, writes 5719 5720 op_queue = [op] 5721 seen_ops = set() 5722 all_reads = [] 5723 all_writes = [] 5724 while op_queue: 5725 reads, writes = _process(op_queue, seen_ops) 5726 all_reads.extend(reads) 5727 all_writes.extend(writes) 5728 5729 return all_reads, all_writes 5730 5731 5732class _SnapshotDataset(UnaryUnchangedStructureDataset): 5733 """A dataset that allows saving and re-use of already processed data.""" 5734 5735 def __init__(self, 5736 input_dataset, 5737 path, 5738 shard_func, 5739 compression=None, 5740 reader_func=None, 5741 pending_snapshot_expiry_seconds=None, 5742 use_legacy_function=False): 5743 5744 if reader_func is None: 5745 reader_func = lambda datasets: datasets.interleave( # pylint:disable=g-long-lambda 5746 lambda x: x, 5747 cycle_length=multiprocessing.cpu_count(), 5748 num_parallel_calls=AUTOTUNE) 5749 5750 self._input_dataset = input_dataset 5751 self._path = path 5752 self._compression = compression 5753 5754 self._reader_func = StructuredFunctionWrapper( 5755 reader_func, 5756 self._transformation_name() + ".reader_func", 5757 # Dataset of datasets of input elements 5758 input_structure=DatasetSpec(DatasetSpec(input_dataset.element_spec)), 5759 use_legacy_function=use_legacy_function) 5760 self._shard_func = StructuredFunctionWrapper( 5761 shard_func, 5762 self._transformation_name() + ".shard_func", 5763 dataset=input_dataset, 5764 use_legacy_function=use_legacy_function) 5765 5766 if ((not self._shard_func.output_structure.is_compatible_with( 5767 tensor_spec.TensorSpec([], dtypes.int32))) and 5768 (not self._shard_func.output_structure.is_compatible_with( 5769 tensor_spec.TensorSpec([], dtypes.int64)))): 5770 raise TypeError( 5771 "shard_func must return a 0-dimension tensor containing an int.") 5772 5773 variant_tensor = ged_ops.snapshot_dataset_v2( 5774 input_dataset._variant_tensor, # pylint: disable=protected-access 5775 path, 5776 self._reader_func.function.captured_inputs, 5777 self._shard_func.function.captured_inputs, 5778 compression=compression, 5779 reader_func=self._reader_func.function, 5780 shard_func=self._shard_func.function, 5781 **self._flat_structure) 5782 super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor) 5783 5784 def _functions(self): 5785 return [self._reader_func, self._shard_func] 5786 5787 def _transformation_name(self): 5788 return "Dataset.snapshot()" 5789 5790 5791class _ScanDataset(UnaryDataset): 5792 """A dataset that scans a function across its input.""" 5793 5794 def __init__(self, 5795 input_dataset, 5796 initial_state, 5797 scan_func, 5798 use_default_device=None): 5799 """See `scan()` for details.""" 5800 self._input_dataset = input_dataset 5801 self._initial_state = structure.normalize_element(initial_state) 5802 5803 # Compute initial values for the state classes, shapes and types based on 5804 # the initial state. The shapes may be refined by running `tf_scan_func` one 5805 # or more times below. 5806 self._state_structure = structure.type_spec_from_value(self._initial_state) 5807 5808 # Iteratively rerun the scan function until reaching a fixed point on 5809 # `self._state_shapes`. 
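    # Each pass traces `scan_func` against the current state structure. If the
    # traced function returns state shapes that are less specific than the
    # current estimate, the state shapes are relaxed and the function is
    # traced again, until the shapes stop changing.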
5810 need_to_rerun = True 5811 while need_to_rerun: 5812 5813 wrapped_func = StructuredFunctionWrapper( 5814 scan_func, 5815 self._transformation_name(), 5816 input_structure=(self._state_structure, input_dataset.element_spec), 5817 add_to_graph=False) 5818 if not (isinstance(wrapped_func.output_types, collections_abc.Sequence) 5819 and len(wrapped_func.output_types) == 2): 5820 raise TypeError("The scan function must return a pair comprising the " 5821 "new state and the output value.") 5822 5823 new_state_classes, self._output_classes = wrapped_func.output_classes 5824 5825 # Extract and validate class information from the returned values. 5826 new_state_classes, output_classes = wrapped_func.output_classes 5827 old_state_classes = nest.map_structure( 5828 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 5829 self._state_structure) 5830 for new_state_class, old_state_class in zip( 5831 nest.flatten(new_state_classes), nest.flatten(old_state_classes)): 5832 if not issubclass(new_state_class, old_state_class): 5833 raise TypeError( 5834 "The element classes for the new state must match the initial " 5835 "state. Expected %s; got %s." % 5836 (old_state_classes, new_state_classes)) 5837 5838 # Extract and validate type information from the returned values. 5839 new_state_types, output_types = wrapped_func.output_types 5840 old_state_types = nest.map_structure( 5841 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 5842 self._state_structure) 5843 for new_state_type, old_state_type in zip( 5844 nest.flatten(new_state_types), nest.flatten(old_state_types)): 5845 if new_state_type != old_state_type: 5846 raise TypeError( 5847 "The element types for the new state must match the initial " 5848 "state. Expected %s; got %s." % 5849 (old_state_types, new_state_types)) 5850 5851 # Extract shape information from the returned values. 5852 new_state_shapes, output_shapes = wrapped_func.output_shapes 5853 old_state_shapes = nest.map_structure( 5854 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 5855 self._state_structure) 5856 self._element_spec = structure.convert_legacy_structure( 5857 output_types, output_shapes, output_classes) 5858 5859 flat_state_shapes = nest.flatten(old_state_shapes) 5860 flat_new_state_shapes = nest.flatten(new_state_shapes) 5861 weakened_state_shapes = [ 5862 original.most_specific_compatible_shape(new) 5863 for original, new in zip(flat_state_shapes, flat_new_state_shapes) 5864 ] 5865 5866 need_to_rerun = False 5867 for original_shape, weakened_shape in zip(flat_state_shapes, 5868 weakened_state_shapes): 5869 if original_shape.ndims is not None and ( 5870 weakened_shape.ndims is None or 5871 original_shape.as_list() != weakened_shape.as_list()): 5872 need_to_rerun = True 5873 break 5874 5875 if need_to_rerun: 5876 # TODO(b/110122868): Support a "most specific compatible structure" 5877 # method for combining structures, to avoid using legacy structures 5878 # in this method. 
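        # Relax the state structure to the weakened shapes computed above so
        # that the next trace of `scan_func` sees the less specific state spec.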
        self._state_structure = structure.convert_legacy_structure(
            old_state_types,
            nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
            old_state_classes)

    self._scan_func = wrapped_func
    self._scan_func.function.add_to_graph(ops.get_default_graph())
    # pylint: disable=protected-access
    if use_default_device is not None:
      variant_tensor = ged_ops.scan_dataset(
          self._input_dataset._variant_tensor,
          structure.to_tensor_list(self._state_structure, self._initial_state),
          self._scan_func.function.captured_inputs,
          f=self._scan_func.function,
          preserve_cardinality=True,
          use_default_device=use_default_device,
          **self._flat_structure)
    else:
      variant_tensor = ged_ops.scan_dataset(
          self._input_dataset._variant_tensor,
          structure.to_tensor_list(self._state_structure, self._initial_state),
          self._scan_func.function.captured_inputs,
          f=self._scan_func.function,
          preserve_cardinality=True,
          **self._flat_structure)
    super(_ScanDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._scan_func]

  @property
  def element_spec(self):
    return self._element_spec

  def _transformation_name(self):
    return "Dataset.scan()"


@auto_control_deps.register_acd_resource_resolver
def _resource_resolver(op, resource_reads, resource_writes):
  """Updates resource inputs for tf.data ops with indirect dependencies."""

  updated = False
  if op.type in [
      "DatasetToSingleElement", "DatasetToTFRecord", "ReduceDataset"
  ]:
    reads, writes = _collect_resource_inputs(op)
    for inp in reads:
      if inp not in resource_reads:
        updated = True
        resource_reads.add(inp)
    for inp in writes:
      if inp not in resource_writes:
        updated = True
        resource_writes.add(inp)

  if op.type in [
      "IteratorGetNext", "IteratorGetNextSync", "IteratorGetNextAsOptional"
  ]:
    iterator_resource = op.inputs[0]
    make_iterator_ops = [
        consumer for consumer in iterator_resource.consumers()
        if consumer.type == "MakeIterator"
    ]

    if len(make_iterator_ops) == 1:
      reads, writes = _collect_resource_inputs(make_iterator_ops[0])
      for inp in reads:
        if inp not in resource_reads:
          updated = True
          resource_reads.add(inp)
      for inp in writes:
        if inp not in resource_writes:
          updated = True
          resource_writes.add(inp)

  return updated


DEBUG_MODE = False


@tf_export("data.experimental.enable_debug_mode")
def enable_debug_mode():
  """Enables debug mode for tf.data.

  Example usage with the pdb module:
  ```
  import tensorflow as tf
  import pdb

  tf.data.experimental.enable_debug_mode()

  def func(x):
    # Python 3.7 and older require `pdb.Pdb(nosigint=True).set_trace()`
    pdb.set_trace()
    x = x + 1
    return x

  dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  dataset = dataset.map(func)

  for item in dataset:
    print(item)
  ```

  The effect of debug mode is twofold:

  1) Any transformations that would introduce asynchrony, parallelism, or
  non-determinism to the input pipeline execution will be forced to execute
  synchronously, sequentially, and deterministically.

  2) Any user-defined functions passed into tf.data transformations such as
  `map` will be wrapped in `tf.py_function` so that their bodies are executed
  "eagerly" as Python functions rather than as traced TensorFlow graphs (the
  default behavior). Note that even when debug mode is enabled, the
  user-defined function is still traced to infer the shape and type of its
  outputs; as a consequence, any `print` statements or breakpoints will be
  triggered once during this tracing, before the actual execution of the input
  pipeline.

  NOTE: As the debug mode setting affects the construction of the tf.data input
  pipeline, it should be enabled before any tf.data definitions.

  Raises:
    ValueError: When invoked from graph mode.
  """
  if context.executing_eagerly():
    toggle_debug_mode(True)
  else:
    raise ValueError("Debug mode is only supported in eager mode.")


def toggle_debug_mode(debug_mode):
  """Sets the tf.data debug mode flag (`DEBUG_MODE`) to the given value."""
  global DEBUG_MODE
  DEBUG_MODE = debug_mode
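
# Illustrative sketch (not part of this module's API): as described in the
# `enable_debug_mode` docstring above, a user-defined function such as the
# hypothetical `func` below is still traced once to infer its output shapes
# and types, so its `print` fires once during tracing and then once per
# element while the pipeline runs eagerly under debug mode.
#
#   tf.data.experimental.enable_debug_mode()
#
#   def func(x):
#     print("func called")
#     return x + 1
#
#   dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).map(func)
#   for element in dataset:
#     print(element)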