1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for `tf.data.Dataset`.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import warnings 23 24from absl.testing import parameterized 25import numpy as np 26 27from tensorflow.core.framework import graph_pb2 28from tensorflow.python.data.experimental.ops import distribute_options 29from tensorflow.python.data.kernel_tests import test_base 30from tensorflow.python.data.ops import dataset_ops 31from tensorflow.python.data.ops import optional_ops 32from tensorflow.python.data.ops import readers 33from tensorflow.python.data.util import nest 34from tensorflow.python.data.util import structure 35from tensorflow.python.eager import context 36from tensorflow.python.eager import def_function 37from tensorflow.python.framework import combinations 38from tensorflow.python.framework import constant_op 39from tensorflow.python.framework import dtypes 40from tensorflow.python.framework import errors 41from tensorflow.python.framework import ops 42from tensorflow.python.framework import sparse_tensor 43from tensorflow.python.framework import tensor_shape 44from tensorflow.python.framework import tensor_spec 45from tensorflow.python.ops import array_ops 46from tensorflow.python.ops import random_ops 47from tensorflow.python.platform import test 48 49 50class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): 51 52 @combinations.generate(test_base.default_test_combinations()) 53 def testAsSerializedGraph(self): 54 dataset = dataset_ops.Dataset.range(10) 55 graph = graph_pb2.GraphDef().FromString( 56 self.evaluate(dataset._as_serialized_graph())) 57 self.assertTrue(any(node.op == "RangeDataset" for node in graph.node)) 58 59 def testAsSerializedGraphStateful(self): 60 dataset = dataset_ops.Dataset.range(10).map( 61 lambda _: random_ops.random_uniform(())) 62 with self.assertRaises(errors.FailedPreconditionError): 63 self.evaluate( 64 dataset._as_serialized_graph(external_state_policy=distribute_options 65 .ExternalStatePolicy.FAIL)) 66 67 @combinations.generate(test_base.default_test_combinations()) 68 def testAsFunctionWithMap(self): 69 if not context.executing_eagerly(): 70 self.skipTest("Only works executing eagerly") 71 with ops.device("CPU"): 72 original_dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2) 73 fn = original_dataset._trace_variant_creation() 74 variant = fn() 75 76 revived_dataset = dataset_ops._VariantDataset( 77 variant, original_dataset.element_spec) 78 self.assertDatasetProduces(revived_dataset, range(0, 10, 2)) 79 80 @combinations.generate(test_base.default_test_combinations()) 81 def testAsFunctionWithMapInFlatMap(self): 82 if not context.executing_eagerly(): 83 self.skipTest("Only works executing eagerly") 84 with ops.device("CPU"): 85 original_dataset = dataset_ops.Dataset.range(5).flat_map( 86 lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2)) 87 fn = original_dataset._trace_variant_creation() 88 variant = fn() 89 90 revived_dataset = dataset_ops._VariantDataset( 91 variant, original_dataset.element_spec) 92 self.assertDatasetProduces(revived_dataset, list(original_dataset)) 93 94 def _testNumInputs(self, dataset, num_inputs): 95 self.assertLen(dataset._inputs(), num_inputs) 96 97 @combinations.generate(test_base.default_test_combinations()) 98 def testFixedLengthRecordInputs(self): 99 dataset = readers.FixedLengthRecordDataset("", 42) 100 self._testNumInputs(dataset, 0) 101 102 @combinations.generate(test_base.default_test_combinations()) 103 def testFromGeneratorInputs(self): 104 def gen(): 105 yield 42 106 107 dataset = dataset_ops.Dataset.from_generator(gen, dtypes.int32) 108 self._testNumInputs(dataset, 1) 109 110 @combinations.generate(test_base.default_test_combinations()) 111 def testFromTensorsInputs(self): 112 dataset = dataset_ops.Dataset.from_tensors([42]) 113 self._testNumInputs(dataset, 0) 114 115 @combinations.generate(test_base.default_test_combinations()) 116 def testRangeInputs(self): 117 dataset = dataset_ops.Dataset.range(10) 118 self._testNumInputs(dataset, 0) 119 120 @combinations.generate(test_base.default_test_combinations()) 121 def testTextLineInputs(self): 122 dataset = readers.TextLineDataset("") 123 self._testNumInputs(dataset, 0) 124 125 @combinations.generate(test_base.default_test_combinations()) 126 def testTFRecordInputs(self): 127 dataset = readers.TFRecordDataset("") 128 self._testNumInputs(dataset, 1) 129 130 @combinations.generate( 131 combinations.combine(tf_api_version=1, mode=["eager", "graph"])) 132 def testDatasetComplexSourceInputs(self): 133 dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices( 134 sparse_tensor.SparseTensor( 135 indices=np.array([[0, 0], [1, 0], [2, 0]]), 136 values=np.array([0, 0, 0]), 137 dense_shape=np.array([3, 1]))) 138 self.assertEmpty(dataset_fn._inputs()) 139 140 def _testUnaryInputs(self, dataset_fn): 141 input_dataset = dataset_ops.Dataset.range(0) 142 self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs()) 143 144 @combinations.generate(test_base.default_test_combinations()) 145 def testBatchInputs(self): 146 self._testUnaryInputs(lambda x: x.batch(10)) 147 148 @combinations.generate(test_base.default_test_combinations()) 149 def testCacheInputs(self): 150 self._testUnaryInputs(lambda x: x.cache()) 151 152 @combinations.generate(test_base.default_test_combinations()) 153 def testFilterInputs(self): 154 self._testUnaryInputs(lambda x: x.filter(lambda x: True)) 155 156 @combinations.generate(test_base.default_test_combinations()) 157 def testFlatMapInputs(self): 158 self._testUnaryInputs( 159 lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0))) 160 161 @combinations.generate(test_base.default_test_combinations()) 162 def testMapInputs(self): 163 self._testUnaryInputs(lambda x: x.map(lambda x: x)) 164 165 @combinations.generate(test_base.default_test_combinations()) 166 def testPaddedBatchInputs(self): 167 self._testUnaryInputs(lambda x: x.padded_batch(10, [])) 168 169 @combinations.generate(test_base.default_test_combinations()) 170 def testParallelMapInputs(self): 171 self._testUnaryInputs(lambda x: x.map(lambda x: x, num_parallel_calls=2)) 172 173 @combinations.generate(test_base.default_test_combinations()) 174 def testRepeatInputs(self): 175 self._testUnaryInputs(lambda x: x.repeat()) 176 177 @combinations.generate(test_base.default_test_combinations()) 178 def testShuffleInputs(self): 179 self._testUnaryInputs(lambda x: x.shuffle(10)) 180 181 @combinations.generate(test_base.default_test_combinations()) 182 def testSkipInputs(self): 183 self._testUnaryInputs(lambda x: x.skip(1)) 184 185 @combinations.generate(test_base.default_test_combinations()) 186 def testTakeInputs(self): 187 self._testUnaryInputs(lambda x: x.take(1)) 188 189 @combinations.generate(test_base.default_test_combinations()) 190 def testWindowInputs(self): 191 self._testUnaryInputs(lambda x: x.window(10)) 192 193 @combinations.generate(test_base.default_test_combinations()) 194 def testUnaryTransformationInputsApply(self): 195 input_dataset = dataset_ops.Dataset.range(0) 196 dataset = input_dataset.apply(lambda dataset: dataset.cache()) 197 198 self.assertEqual([input_dataset], dataset._inputs()) 199 200 def _testInputsWithInterleaveFn(self, dataset_fn, interleave_parallelism): 201 input_dataset = dataset_ops.Dataset.range(0) 202 dataset = input_dataset.interleave( 203 lambda x: dataset_ops.Dataset.range(0), 204 cycle_length=2, 205 num_parallel_calls=interleave_parallelism) 206 self.assertEqual([input_dataset], dataset._inputs()) 207 208 @combinations.generate(test_base.default_test_combinations()) 209 def testParallelInterleaveInputs(self): 210 self._testInputsWithInterleaveFn(lambda: dataset_ops.range(0), 2) 211 212 @combinations.generate(test_base.default_test_combinations()) 213 def testInterleaveInputs(self): 214 self._testInputsWithInterleaveFn(lambda: dataset_ops.range(0), None) 215 216 @combinations.generate(test_base.default_test_combinations()) 217 def testNoWarnings(self): 218 with test.mock.patch.object(warnings, "warn") as mock_log: 219 dataset_ops.Dataset.range(0).interleave( 220 lambda x: dataset_ops.Dataset.range(0), cycle_length=2) 221 self.assertEmpty(mock_log.call_args_list) 222 223 def _testBinaryInputs(self, dataset_fn): 224 input1 = dataset_ops.Dataset.range(0) 225 input2 = dataset_ops.Dataset.range(1) 226 self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs()) 227 228 @combinations.generate(test_base.default_test_combinations()) 229 def testConcatenateInputs(self): 230 self._testBinaryInputs(lambda x, y: x.concatenate(y)) 231 232 def _testVariadicInputs(self, dataset_fn, input_datasets): 233 self.assertEqual( 234 nest.flatten(input_datasets), 235 dataset_fn(input_datasets)._inputs()) 236 237 @combinations.generate(test_base.default_test_combinations()) 238 def testZipOneInputs(self): 239 input_datasets = dataset_ops.Dataset.range(0) 240 self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets) 241 242 @combinations.generate(test_base.default_test_combinations()) 243 def testZipNestInputs(self): 244 input_datasets = (dataset_ops.Dataset.range(0), 245 (dataset_ops.Dataset.range(1), 246 dataset_ops.Dataset.range(2))) 247 self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets) 248 249 @combinations.generate(test_base.default_test_combinations()) 250 def testZipTupleInputs(self): 251 input_datasets = (dataset_ops.Dataset.range(0), 252 dataset_ops.Dataset.range(1)) 253 self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets) 254 255 @combinations.generate(test_base.default_test_combinations()) 256 def testFunctions(self): 257 dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2) 258 self.assertLen(dataset._functions(), 1) 259 260 @combinations.generate(test_base.default_test_combinations()) 261 def testCollectInputs(self): 262 ds1 = dataset_ops.Dataset.range(0) 263 ds2 = ds1.concatenate(ds1) 264 ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2)) 265 266 inputs = [] 267 queue = [ds3] 268 while queue: 269 ds = queue[0] 270 queue = queue[1:] 271 queue.extend(ds._inputs()) 272 inputs.append(ds) 273 274 self.assertEqual(5, inputs.count(ds1)) 275 self.assertEqual(2, inputs.count(ds2)) 276 self.assertEqual(1, inputs.count(ds3)) 277 278 def _testDatasetSpec(self, tf_value, expected_element_structure): 279 dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value) 280 dataset_structure = structure.type_spec_from_value(dataset) 281 self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec) 282 283 self.assertTrue( 284 structure.are_compatible( 285 dataset_ops.get_structure(dataset), expected_element_structure)) 286 self.assertEqual([dtypes.variant], 287 structure.get_flat_tensor_types(dataset_structure)) 288 self.assertEqual([tensor_shape.TensorShape([])], 289 structure.get_flat_tensor_shapes(dataset_structure)) 290 291 # Assert that the `Dataset` survives a round-trip via _from_tensor_list() 292 # and _to_tensor_list(). 293 round_trip_dataset = dataset_structure._from_tensor_list( 294 dataset_structure._to_tensor_list(dataset)) 295 296 value = tf_value 297 298 if isinstance(value, dataset_ops.Dataset): 299 self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x)) 300 elif isinstance(value, optional_ops.Optional): 301 self.assertDatasetProduces( 302 round_trip_dataset.map(lambda opt: opt.get_value()), 303 [self.evaluate(value.get_value())], 304 requires_initialization=True) 305 else: 306 self.assertDatasetProduces( 307 round_trip_dataset, [self.evaluate(tf_value)], 308 requires_initialization=True) 309 310 @combinations.generate(test_base.default_test_combinations()) 311 def testTensorDatasetSpec(self): 312 self._testDatasetSpec( 313 constant_op.constant(37.0), tensor_spec.TensorSpec([], dtypes.float32)) 314 315 @combinations.generate(test_base.default_test_combinations()) 316 def testSparseTensorDatasetSpec(self): 317 self._testDatasetSpec( 318 sparse_tensor.SparseTensor( 319 indices=[[0]], 320 values=constant_op.constant([0], dtype=dtypes.int32), 321 dense_shape=[1]), sparse_tensor.SparseTensorSpec([1], dtypes.int32)) 322 323 @combinations.generate(test_base.default_test_combinations()) 324 def testNestDatasetSpec(self): 325 self._testDatasetSpec( 326 { 327 "a": constant_op.constant(37.0), 328 "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")) 329 }, { 330 "a": 331 tensor_spec.TensorSpec([], dtypes.float32), 332 "b": ( 333 tensor_spec.TensorSpec([1], dtypes.string), 334 tensor_spec.TensorSpec([], dtypes.string), 335 ) 336 }) 337 338 @combinations.generate(test_base.default_test_combinations()) 339 def testDatasetDatasetSpec(self): 340 self._testDatasetSpec( 341 dataset_ops.Dataset.from_tensor_slices( 342 constant_op.constant([1, 2, 3])), 343 dataset_ops.DatasetSpec(tensor_spec.TensorSpec([], dtypes.int32))) 344 345 @combinations.generate(test_base.default_test_combinations()) 346 def testOptionalDatasetSpec(self): 347 self._testDatasetSpec( 348 optional_ops.Optional.from_value(37.0), 349 optional_ops.OptionalSpec(tensor_spec.TensorSpec([], dtypes.float32))) 350 351 @combinations.generate(test_base.graph_only_combinations()) 352 def testSameGraphError(self): 353 dataset = dataset_ops.Dataset.range(10) 354 with ops.Graph().as_default(): 355 with self.assertRaisesRegex(ValueError, "must be from the same graph"): 356 dataset = dataset.batch(2) 357 358 @combinations.generate( 359 combinations.combine(tf_api_version=[1], mode=["graph"])) 360 def testSameGraphErrorOneShot(self): 361 dataset = dataset_ops.Dataset.range(10) 362 with ops.Graph().as_default(): 363 with self.assertRaisesRegex( 364 ValueError, "Please ensure that all datasets in the pipeline are " 365 "created in the same graph as the iterator."): 366 _ = dataset_ops.make_one_shot_iterator(dataset) 367 368 @combinations.generate( 369 combinations.combine(tf_api_version=[1], mode=["graph"])) 370 def testSameGraphErrorInitializable(self): 371 dataset = dataset_ops.Dataset.range(10) 372 with ops.Graph().as_default(): 373 with self.assertRaisesRegex( 374 ValueError, "Please ensure that all datasets in the pipeline are " 375 "created in the same graph as the iterator."): 376 _ = dataset_ops.make_initializable_iterator(dataset) 377 378 @combinations.generate( 379 combinations.times( 380 test_base.eager_only_combinations(), 381 combinations.combine(execution_mode=[context.ASYNC, context.SYNC]))) 382 def testEagerIteration(self, execution_mode): 383 with context.execution_mode(execution_mode): 384 val = 0 385 dataset = dataset_ops.Dataset.range(10) 386 for foo in dataset: 387 self.assertEqual(val, foo.numpy()) 388 val += 1 389 390 @combinations.generate(test_base.default_test_combinations()) 391 def testDatasetAsFunctionArgument(self): 392 393 @def_function.function 394 def _uses_dataset(d): 395 accumulator = array_ops.zeros([], dtype=dtypes.int64) 396 for value in d: 397 accumulator += value 398 return accumulator 399 400 with ops.device("CPU"): 401 first_dataset = dataset_ops.Dataset.range(10) 402 self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset))) 403 second_dataset = dataset_ops.Dataset.range(11) 404 self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset))) 405 first_concrete = _uses_dataset.get_concrete_function(first_dataset) 406 # The dataset should not be a captured input 407 self.assertEmpty(first_concrete.graph.captures) 408 # The two datasets have the same structure and so should re-use a trace. 409 self.assertIs(first_concrete, 410 _uses_dataset.get_concrete_function(second_dataset)) 411 # With a different structure we should use a different trace. 412 self.assertIsNot( 413 first_concrete, 414 _uses_dataset.get_concrete_function( 415 dataset_ops.Dataset.zip((first_dataset, second_dataset)))) 416 417 @combinations.generate(test_base.default_test_combinations()) 418 def testLimitedRetracing(self): 419 trace_count = [0] 420 421 @def_function.function 422 def f(ds): 423 trace_count[0] += 1 424 counter = np.int64(0) 425 for elem in ds: 426 counter += elem 427 return counter 428 429 dataset = dataset_ops.Dataset.range(5) 430 dataset2 = dataset_ops.Dataset.range(10) 431 432 for _ in range(10): 433 self.assertEqual(self.evaluate(f(dataset)), 10) 434 self.assertEqual(self.evaluate(f(dataset2)), 45) 435 self.assertEqual(trace_count[0], 1) 436 437 # pylint: disable=g-long-lambda,unnecessary-lambda 438 @combinations.generate(test_base.default_test_combinations()) 439 def testLegacyStructureAPI(self): 440 components = (np.array([1, 2, 3], dtype=np.int64), (np.array([4., 5.]), 441 np.array([6., 7.])), 442 np.array([8, 9, 10], dtype=np.int64)) 443 444 dataset = dataset_ops.Dataset.from_tensors(components) 445 self.assertEqual( 446 (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), 447 dataset_ops.get_legacy_output_types(dataset)) 448 self.assertEqual(([3], ([2], [2]), [3]), 449 dataset_ops.get_legacy_output_shapes(dataset)) 450 451 dataset = dataset.shuffle(10, 10) 452 self.assertEqual( 453 (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), 454 dataset_ops.get_legacy_output_types(dataset)) 455 self.assertEqual(([3], ([2], [2]), [3]), 456 dataset_ops.get_legacy_output_shapes(dataset)) 457 458 dataset = dataset.repeat(-1) 459 self.assertEqual( 460 (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), 461 dataset_ops.get_legacy_output_types(dataset)) 462 self.assertEqual(([3], ([2], [2]), [3]), 463 dataset_ops.get_legacy_output_shapes(dataset)) 464 465 dataset = dataset.filter(lambda x, y, z: True) 466 self.assertEqual( 467 (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), 468 dataset_ops.get_legacy_output_types(dataset)) 469 self.assertEqual(([3], ([2], [2]), [3]), 470 dataset_ops.get_legacy_output_shapes(dataset)) 471 472 dataset = dataset.take(5) 473 self.assertEqual( 474 (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), 475 dataset_ops.get_legacy_output_types(dataset)) 476 self.assertEqual(([3], ([2], [2]), [3]), 477 dataset_ops.get_legacy_output_shapes(dataset)) 478 479 dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1]))) 480 self.assertEqual( 481 ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)), 482 dataset_ops.get_legacy_output_types(dataset)) 483 self.assertEqual((([3], [3]), ([2], [2])), 484 dataset_ops.get_legacy_output_shapes(dataset)) 485 486 dataset = dataset.flat_map(lambda x, y: dataset_ops.Dataset.from_tensors( 487 ((x[0], x[1]), (y[0], y[1])))) 488 self.assertEqual( 489 ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)), 490 dataset_ops.get_legacy_output_types(dataset)) 491 self.assertEqual((([3], [3]), ([2], [2])), 492 dataset_ops.get_legacy_output_shapes(dataset)) 493 494 dataset = dataset.batch(32) 495 self.assertEqual( 496 ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)), 497 dataset_ops.get_legacy_output_types(dataset)) 498 dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset) 499 self.assertEqual( 500 (([None, 3], [None, 3]), ([None, 2], [None, 2])), 501 nest.pack_sequence_as( 502 dataset_output_shapes, 503 [s.as_list() for s in nest.flatten(dataset_output_shapes)])) 504 505 # Define a separate set of components with matching leading 506 # dimension for the from-slices constructor. 507 components_for_slices = (np.array([1, 2, 3], 508 dtype=np.int64), (np.array([4., 5., 6.]), 509 np.array([7., 8., 9.])), 510 np.array([10, 11, 12], dtype=np.int64)) 511 512 dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices) 513 self.assertEqual( 514 (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), 515 dataset_ops.get_legacy_output_types(dataset)) 516 self.assertEqual(([], ([], []), []), 517 dataset_ops.get_legacy_output_shapes(dataset)) 518 519 @combinations.generate(test_base.default_test_combinations()) 520 def testNoneComponent(self): 521 dataset = dataset_ops.Dataset.from_tensors((42, None)) 522 if context.executing_eagerly(): 523 self.assertDatasetProduces(dataset, expected_output=[(42, None)]) 524 else: 525 iterator = dataset_ops.make_one_shot_iterator(dataset) 526 next_first, next_second = iterator.get_next() 527 self.assertEqual(next_second, None) 528 with self.cached_session() as sess: 529 self.assertEqual(sess.run(next_first), 42) 530 531 @combinations.generate(test_base.default_test_combinations()) 532 def testNoneComponentInFunction(self): 533 534 @def_function.function 535 def fn(ds): 536 total = 0 537 it = iter(ds) 538 for elem in it: 539 x, _ = elem 540 total += x 541 return total 542 543 dataset = dataset_ops.Dataset.range( 544 10, output_type=dtypes.int32).map(lambda x: (x, None)) 545 self.assertEqual(self.evaluate(fn(dataset)), 45) 546 547 @combinations.generate(test_base.default_test_combinations()) 548 def testIncorrectPythonStructure(self): 549 # Tests that an exception is raised (as opposed to a segfault) when the 550 # Python structure assigned to a dataset is incorrect. 551 dataset = dataset_ops.Dataset.range(10) 552 spec = tensor_spec.TensorSpec([], dtypes.int64) 553 new_structure = (spec, spec) 554 dataset = dataset_ops._RestructuredDataset(dataset, new_structure) 555 dataset = dataset.map(lambda x, y: y) 556 557 with self.assertRaisesOpError(""): 558 self.getDatasetOutput(dataset) 559 560 def testNamedTupleStructure(self): 561 Foo = collections.namedtuple("Foo", ["a", "b"]) 562 x = Foo(a=3, b="test") 563 dataset = dataset_ops.Dataset.from_tensors(x) 564 dataset = dataset_ops.Dataset.from_tensor_slices([dataset, dataset]) 565 self.assertEqual( 566 str(dataset.element_spec), 567 "DatasetSpec(Foo(a=TensorSpec(shape=(), dtype=tf.int32, name=None), " 568 "b=TensorSpec(shape=(), dtype=tf.string, name=None)), TensorShape([]))") 569 570 571if __name__ == "__main__": 572 test.main() 573