• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.data.Dataset`."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import warnings
23
24from absl.testing import parameterized
25import numpy as np
26
27from tensorflow.core.framework import graph_pb2
28from tensorflow.python.data.experimental.ops import distribute_options
29from tensorflow.python.data.kernel_tests import test_base
30from tensorflow.python.data.ops import dataset_ops
31from tensorflow.python.data.ops import optional_ops
32from tensorflow.python.data.ops import readers
33from tensorflow.python.data.util import nest
34from tensorflow.python.data.util import structure
35from tensorflow.python.eager import context
36from tensorflow.python.eager import def_function
37from tensorflow.python.framework import combinations
38from tensorflow.python.framework import constant_op
39from tensorflow.python.framework import dtypes
40from tensorflow.python.framework import errors
41from tensorflow.python.framework import ops
42from tensorflow.python.framework import sparse_tensor
43from tensorflow.python.framework import tensor_shape
44from tensorflow.python.framework import tensor_spec
45from tensorflow.python.ops import array_ops
46from tensorflow.python.ops import random_ops
47from tensorflow.python.platform import test
48
49
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for core `tf.data.Dataset` functionality.

  Covers graph serialization, variant-creation tracing, the `_inputs()` and
  `_functions()` introspection hooks, `DatasetSpec` round trips, cross-graph
  errors, `tf.function` interaction, and the legacy structure API.
  """

  @combinations.generate(test_base.default_test_combinations())
  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  @combinations.generate(test_base.default_test_combinations())
  def testAsSerializedGraphStateful(self):
    # The map function draws random numbers; with the FAIL external-state
    # policy, serializing the pipeline is expected to be rejected.
    dataset = dataset_ops.Dataset.range(10).map(
        lambda _: random_ops.random_uniform(()))
    with self.assertRaises(errors.FailedPreconditionError):
      self.evaluate(
          dataset._as_serialized_graph(external_state_policy=distribute_options
                                       .ExternalStatePolicy.FAIL))

  @combinations.generate(test_base.default_test_combinations())
  def testAsFunctionWithMap(self):
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      # Rebuild a dataset from the traced variant and check that it yields the
      # same elements as the original pipeline.
      revived_dataset = dataset_ops._VariantDataset(
          variant, original_dataset.element_spec)
      self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

  @combinations.generate(test_base.default_test_combinations())
  def testAsFunctionWithMapInFlatMap(self):
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).flat_map(
          lambda _: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = dataset_ops._VariantDataset(
          variant, original_dataset.element_spec)
      self.assertDatasetProduces(revived_dataset, list(original_dataset))

  def _testNumInputs(self, dataset, num_inputs):
    """Asserts that `dataset._inputs()` has exactly `num_inputs` entries."""
    self.assertLen(dataset._inputs(), num_inputs)

  @combinations.generate(test_base.default_test_combinations())
  def testFixedLengthRecordInputs(self):
    dataset = readers.FixedLengthRecordDataset("", 42)
    self._testNumInputs(dataset, 0)

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorInputs(self):
    def gen():
      yield 42

    dataset = dataset_ops.Dataset.from_generator(gen, dtypes.int32)
    self._testNumInputs(dataset, 1)

  @combinations.generate(test_base.default_test_combinations())
  def testFromTensorsInputs(self):
    dataset = dataset_ops.Dataset.from_tensors([42])
    self._testNumInputs(dataset, 0)

  @combinations.generate(test_base.default_test_combinations())
  def testRangeInputs(self):
    dataset = dataset_ops.Dataset.range(10)
    self._testNumInputs(dataset, 0)

  @combinations.generate(test_base.default_test_combinations())
  def testTextLineInputs(self):
    dataset = readers.TextLineDataset("")
    self._testNumInputs(dataset, 0)

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordInputs(self):
    dataset = readers.TFRecordDataset("")
    self._testNumInputs(dataset, 1)

  @combinations.generate(
      combinations.combine(tf_api_version=1, mode=["eager", "graph"]))
  def testDatasetComplexSourceInputs(self):
    dataset = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEmpty(dataset._inputs())

  def _testUnaryInputs(self, dataset_fn):
    """Asserts that `dataset_fn(ds)._inputs()` is exactly `[ds]`.

    Args:
      dataset_fn: Callable applying a single-input transformation to a dataset.
    """
    input_dataset = dataset_ops.Dataset.range(0)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @combinations.generate(test_base.default_test_combinations())
  def testBatchInputs(self):
    self._testUnaryInputs(lambda x: x.batch(10))

  @combinations.generate(test_base.default_test_combinations())
  def testCacheInputs(self):
    self._testUnaryInputs(lambda x: x.cache())

  @combinations.generate(test_base.default_test_combinations())
  def testFilterInputs(self):
    self._testUnaryInputs(lambda x: x.filter(lambda _: True))

  @combinations.generate(test_base.default_test_combinations())
  def testFlatMapInputs(self):
    self._testUnaryInputs(
        lambda x: x.flat_map(lambda _: dataset_ops.Dataset.range(0)))

  @combinations.generate(test_base.default_test_combinations())
  def testMapInputs(self):
    self._testUnaryInputs(lambda x: x.map(lambda elem: elem))

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchInputs(self):
    self._testUnaryInputs(lambda x: x.padded_batch(10, []))

  @combinations.generate(test_base.default_test_combinations())
  def testParallelMapInputs(self):
    self._testUnaryInputs(
        lambda x: x.map(lambda elem: elem, num_parallel_calls=2))

  @combinations.generate(test_base.default_test_combinations())
  def testRepeatInputs(self):
    self._testUnaryInputs(lambda x: x.repeat())

  @combinations.generate(test_base.default_test_combinations())
  def testShuffleInputs(self):
    self._testUnaryInputs(lambda x: x.shuffle(10))

  @combinations.generate(test_base.default_test_combinations())
  def testSkipInputs(self):
    self._testUnaryInputs(lambda x: x.skip(1))

  @combinations.generate(test_base.default_test_combinations())
  def testTakeInputs(self):
    self._testUnaryInputs(lambda x: x.take(1))

  @combinations.generate(test_base.default_test_combinations())
  def testWindowInputs(self):
    self._testUnaryInputs(lambda x: x.window(10))

  @combinations.generate(test_base.default_test_combinations())
  def testUnaryTransformationInputsApply(self):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset = input_dataset.apply(lambda dataset: dataset.cache())

    self.assertEqual([input_dataset], dataset._inputs())

  def _testInputsWithInterleaveFn(self, dataset_fn, interleave_parallelism):
    """Asserts that an `interleave` over a range dataset reports one input.

    Args:
      dataset_fn: Zero-argument callable returning the dataset that the
        interleave function produces for each input element.
      interleave_parallelism: Value forwarded to `num_parallel_calls`
        (`None` or a positive int in the callers below).
    """
    input_dataset = dataset_ops.Dataset.range(0)
    dataset = input_dataset.interleave(
        lambda _: dataset_fn(),
        cycle_length=2,
        num_parallel_calls=interleave_parallelism)
    # Datasets created inside the interleave function must not be reported as
    # inputs; only the dataset being interleaved over is.
    self.assertEqual([input_dataset], dataset._inputs())

  @combinations.generate(test_base.default_test_combinations())
  def testParallelInterleaveInputs(self):
    self._testInputsWithInterleaveFn(lambda: dataset_ops.Dataset.range(0), 2)

  @combinations.generate(test_base.default_test_combinations())
  def testInterleaveInputs(self):
    self._testInputsWithInterleaveFn(lambda: dataset_ops.Dataset.range(0),
                                     None)

  @combinations.generate(test_base.default_test_combinations())
  def testNoWarnings(self):
    with test.mock.patch.object(warnings, "warn") as mock_warn:
      dataset_ops.Dataset.range(0).interleave(
          lambda _: dataset_ops.Dataset.range(0), cycle_length=2)
      self.assertEmpty(mock_warn.call_args_list)

  def _testBinaryInputs(self, dataset_fn):
    """Asserts that `dataset_fn(a, b)._inputs()` is exactly `[a, b]`."""
    input1 = dataset_ops.Dataset.range(0)
    input2 = dataset_ops.Dataset.range(1)
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @combinations.generate(test_base.default_test_combinations())
  def testConcatenateInputs(self):
    self._testBinaryInputs(lambda x, y: x.concatenate(y))

  def _testVariadicInputs(self, dataset_fn, input_datasets):
    """Asserts `_inputs()` equals the flattened `input_datasets` structure.

    Args:
      dataset_fn: Callable taking the (possibly nested) `input_datasets`.
      input_datasets: A dataset or nested tuple of datasets.
    """
    self.assertEqual(
        nest.flatten(input_datasets),
        dataset_fn(input_datasets)._inputs())

  @combinations.generate(test_base.default_test_combinations())
  def testZipOneInputs(self):
    input_datasets = dataset_ops.Dataset.range(0)
    self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)

  @combinations.generate(test_base.default_test_combinations())
  def testZipNestInputs(self):
    input_datasets = (dataset_ops.Dataset.range(0),
                      (dataset_ops.Dataset.range(1),
                       dataset_ops.Dataset.range(2)))
    self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)

  @combinations.generate(test_base.default_test_combinations())
  def testZipTupleInputs(self):
    input_datasets = (dataset_ops.Dataset.range(0),
                      dataset_ops.Dataset.range(1))
    self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)

  @combinations.generate(test_base.default_test_combinations())
  def testFunctions(self):
    dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    self.assertLen(dataset._functions(), 1)

  @combinations.generate(test_base.default_test_combinations())
  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    # Breadth-first walk over `_inputs()`, recording every visit (datasets
    # reachable along several paths are recorded once per path).
    inputs = []
    queue = collections.deque([ds3])
    while queue:
      ds = queue.popleft()
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  def _testDatasetSpec(self, tf_value, expected_element_structure):
    """Checks the `DatasetSpec` of a dataset whose single element is `tf_value`.

    Builds `from_tensors(0).map(lambda _: tf_value)`, checks its type spec is a
    `DatasetSpec` encoded as one scalar variant tensor, checks the element
    structure, and checks the elements produced after a tensor-list round trip.

    Args:
      tf_value: Value the dataset's element maps to (tensor, sparse tensor,
        nested structure, `Dataset`, or `Optional` in the callers below).
      expected_element_structure: Structure expected to be compatible with the
        dataset's element structure.
    """
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value)
    dataset_structure = structure.type_spec_from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)

    self.assertTrue(
        structure.are_compatible(
            dataset_ops.get_structure(dataset), expected_element_structure))
    self.assertEqual([dtypes.variant],
                     structure.get_flat_tensor_types(dataset_structure))
    self.assertEqual([tensor_shape.TensorShape([])],
                     structure.get_flat_tensor_shapes(dataset_structure))

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    if isinstance(tf_value, dataset_ops.Dataset):
      # NOTE(review): this branch compares against `dataset`, not
      # `round_trip_dataset`, so the round trip itself is not exercised for
      # nested datasets. Consider switching once verified in graph mode.
      self.assertDatasetsEqual(tf_value, dataset.flat_map(lambda x: x))
    elif isinstance(tf_value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(tf_value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value)],
          requires_initialization=True)

  @combinations.generate(test_base.default_test_combinations())
  def testTensorDatasetSpec(self):
    self._testDatasetSpec(
        constant_op.constant(37.0), tensor_spec.TensorSpec([], dtypes.float32))

  @combinations.generate(test_base.default_test_combinations())
  def testSparseTensorDatasetSpec(self):
    self._testDatasetSpec(
        sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1]), sparse_tensor.SparseTensorSpec([1], dtypes.int32))

  @combinations.generate(test_base.default_test_combinations())
  def testNestDatasetSpec(self):
    self._testDatasetSpec(
        {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        }, {
            "a":
                tensor_spec.TensorSpec([], dtypes.float32),
            "b": (
                tensor_spec.TensorSpec([1], dtypes.string),
                tensor_spec.TensorSpec([], dtypes.string),
            )
        })

  @combinations.generate(test_base.default_test_combinations())
  def testDatasetDatasetSpec(self):
    self._testDatasetSpec(
        dataset_ops.Dataset.from_tensor_slices(
            constant_op.constant([1, 2, 3])),
        dataset_ops.DatasetSpec(tensor_spec.TensorSpec([], dtypes.int32)))

  @combinations.generate(test_base.default_test_combinations())
  def testOptionalDatasetSpec(self):
    self._testDatasetSpec(
        optional_ops.Optional.from_value(37.0),
        optional_ops.OptionalSpec(tensor_spec.TensorSpec([], dtypes.float32)))

  @combinations.generate(test_base.graph_only_combinations())
  def testSameGraphError(self):
    dataset = dataset_ops.Dataset.range(10)
    # Transforming a dataset from inside a different graph must be rejected.
    with ops.Graph().as_default():
      with self.assertRaisesRegex(ValueError, "must be from the same graph"):
        _ = dataset.batch(2)

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testSameGraphErrorOneShot(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          ValueError, "Please ensure that all datasets in the pipeline are "
          "created in the same graph as the iterator."):
        _ = dataset_ops.make_one_shot_iterator(dataset)

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testSameGraphErrorInitializable(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          ValueError, "Please ensure that all datasets in the pipeline are "
          "created in the same graph as the iterator."):
        _ = dataset_ops.make_initializable_iterator(dataset)

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(execution_mode=[context.ASYNC, context.SYNC])))
  def testEagerIteration(self, execution_mode):
    with context.execution_mode(execution_mode):
      dataset = dataset_ops.Dataset.range(10)
      # Comparing whole lists also catches a short or over-long iteration,
      # which a per-element check alone would miss.
      self.assertEqual(list(range(10)), [elem.numpy() for elem in dataset])

  @combinations.generate(test_base.default_test_combinations())
  def testDatasetAsFunctionArgument(self):

    @def_function.function
    def _uses_dataset(d):
      accumulator = array_ops.zeros([], dtype=dtypes.int64)
      for value in d:
        accumulator += value
      return accumulator

    with ops.device("CPU"):
      first_dataset = dataset_ops.Dataset.range(10)
      self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
      second_dataset = dataset_ops.Dataset.range(11)
      self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
      first_concrete = _uses_dataset.get_concrete_function(first_dataset)
      # The dataset should not be a captured input
      self.assertEmpty(first_concrete.graph.captures)
      # The two datasets have the same structure and so should re-use a trace.
      self.assertIs(first_concrete,
                    _uses_dataset.get_concrete_function(second_dataset))
      # With a different structure we should use a different trace.
      self.assertIsNot(
          first_concrete,
          _uses_dataset.get_concrete_function(
              dataset_ops.Dataset.zip((first_dataset, second_dataset))))

  @combinations.generate(test_base.default_test_combinations())
  def testLimitedRetracing(self):
    # A one-element list lets the traced function bump the counter without
    # needing `nonlocal`.
    trace_count = [0]

    @def_function.function
    def f(ds):
      trace_count[0] += 1
      counter = np.int64(0)
      for elem in ds:
        counter += elem
      return counter

    dataset = dataset_ops.Dataset.range(5)
    dataset2 = dataset_ops.Dataset.range(10)

    # Datasets with identical element structure must share one trace.
    for _ in range(10):
      self.assertEqual(self.evaluate(f(dataset)), 10)
      self.assertEqual(self.evaluate(f(dataset2)), 45)
      self.assertEqual(trace_count[0], 1)

  def _assertLegacyStructure(self, dataset, expected_types, expected_shapes):
    """Asserts the legacy output types and shapes of `dataset`."""
    self.assertEqual(expected_types,
                     dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual(expected_shapes,
                     dataset_ops.get_legacy_output_shapes(dataset))

  # pylint: disable=g-long-lambda,unnecessary-lambda
  @combinations.generate(test_base.default_test_combinations())
  def testLegacyStructureAPI(self):
    components = (np.array([1, 2, 3], dtype=np.int64),
                  (np.array([4., 5.]), np.array([6., 7.])),
                  np.array([8, 9, 10], dtype=np.int64))
    nested_types = (dtypes.int64, (dtypes.float64, dtypes.float64),
                    dtypes.int64)
    nested_shapes = ([3], ([2], [2]), [3])

    dataset = dataset_ops.Dataset.from_tensors(components)
    self._assertLegacyStructure(dataset, nested_types, nested_shapes)

    # Each of these transformations must leave the element structure as is.
    structure_preserving_transforms = (
        lambda ds: ds.shuffle(10, 10),
        lambda ds: ds.repeat(-1),
        lambda ds: ds.filter(lambda x, y, z: True),
        lambda ds: ds.take(5),
    )
    for transform in structure_preserving_transforms:
      dataset = transform(dataset)
      self._assertLegacyStructure(dataset, nested_types, nested_shapes)

    # `map` regroups the components into a pair of pairs.
    paired_types = ((dtypes.int64, dtypes.int64),
                    (dtypes.float64, dtypes.float64))
    paired_shapes = (([3], [3]), ([2], [2]))
    dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1])))
    self._assertLegacyStructure(dataset, paired_types, paired_shapes)

    dataset = dataset.flat_map(lambda x, y: dataset_ops.Dataset.from_tensors(
        ((x[0], x[1]), (y[0], y[1]))))
    self._assertLegacyStructure(dataset, paired_types, paired_shapes)

    # `batch` prepends an unknown (None) batch dimension to every component.
    # Shapes are compared via `as_list()` here, presumably because shapes with
    # unknown dimensions do not compare equal directly.
    dataset = dataset.batch(32)
    self.assertEqual(paired_types,
                     dataset_ops.get_legacy_output_types(dataset))
    dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
    self.assertEqual(
        (([None, 3], [None, 3]), ([None, 2], [None, 2])),
        nest.pack_sequence_as(
            dataset_output_shapes,
            [s.as_list() for s in nest.flatten(dataset_output_shapes)]))

    # Define a separate set of components with matching leading
    # dimension for the from-slices constructor.
    components_for_slices = (np.array([1, 2, 3], dtype=np.int64),
                             (np.array([4., 5., 6.]), np.array([7., 8., 9.])),
                             np.array([10, 11, 12], dtype=np.int64))

    dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices)
    self._assertLegacyStructure(dataset, nested_types, ([], ([], []), []))

  @combinations.generate(test_base.default_test_combinations())
  def testNoneComponent(self):
    dataset = dataset_ops.Dataset.from_tensors((42, None))
    if context.executing_eagerly():
      self.assertDatasetProduces(dataset, expected_output=[(42, None)])
    else:
      iterator = dataset_ops.make_one_shot_iterator(dataset)
      next_first, next_second = iterator.get_next()
      self.assertIsNone(next_second)
      with self.cached_session() as sess:
        self.assertEqual(sess.run(next_first), 42)

  @combinations.generate(test_base.default_test_combinations())
  def testNoneComponentInFunction(self):

    @def_function.function
    def fn(ds):
      total = 0
      it = iter(ds)
      for elem in it:
        x, _ = elem
        total += x
      return total

    dataset = dataset_ops.Dataset.range(
        10, output_type=dtypes.int32).map(lambda x: (x, None))
    self.assertEqual(self.evaluate(fn(dataset)), 45)

  @combinations.generate(test_base.default_test_combinations())
  def testIncorrectPythonStructure(self):
    # Tests that an exception is raised (as opposed to a segfault) when the
    # Python structure assigned to a dataset is incorrect.
    dataset = dataset_ops.Dataset.range(10)
    spec = tensor_spec.TensorSpec([], dtypes.int64)
    new_structure = (spec, spec)
    dataset = dataset_ops._RestructuredDataset(dataset, new_structure)
    dataset = dataset.map(lambda x, y: y)

    with self.assertRaisesOpError(""):
      self.getDatasetOutput(dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testNamedTupleStructure(self):
    Foo = collections.namedtuple("Foo", ["a", "b"])
    x = Foo(a=3, b="test")
    dataset = dataset_ops.Dataset.from_tensors(x)
    dataset = dataset_ops.Dataset.from_tensor_slices([dataset, dataset])
    self.assertEqual(
        str(dataset.element_spec),
        "DatasetSpec(Foo(a=TensorSpec(shape=(), dtype=tf.int32, name=None), "
        "b=TensorSpec(shape=(), dtype=tf.string, name=None)), TensorShape([]))")
569
570
if __name__ == "__main__":
  # Run every test case defined in this module through TensorFlow's
  # test entry point when the file is executed directly.
  test.main()
573