# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for testing serializable datasets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np

from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import nest


def remove_variants(get_next_op):
  # TODO(b/72408568): Remove this once session.run can get variant tensors.
  """Remove variants from a nest structure, so sess.run will execute."""

  def _remove_variant(x):
    # Variant-dtype tensors cannot be fetched by session.run; replace them
    # with an empty tuple so the rest of the structure can still be fetched.
    if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
      return ()
    else:
      return x

  return nest.map_structure(_remove_variant, get_next_op)


class DatasetSerializationTestBase(test.TestCase):
  """Base class for testing serializable datasets."""

  def tearDown(self):
    self._delete_ckpt()
    super(DatasetSerializationTestBase, self).tearDown()

  # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
  # (deprecated) saveable `SparseTensorSliceDataset`, once the API
  # `from_sparse_tensor_slices()`and related tests are deleted.
  def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
    """Runs the core tests.

    Args:
      ds_fn1: 0-argument function that returns a Dataset.
      ds_fn2: 0-argument function that returns a Dataset different from
        ds_fn1. If None, verify_restore_in_modified_graph test is not run.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).

    Raises:
      AssertionError if any test fails.
    """
    # NOTE: We disable all default optimizations in serialization tests in
    # order to test the actual dataset in question.
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False

    def ds_fn1_no_opt():
      return ds_fn1().with_options(options)

    self.verify_unused_iterator(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_fully_used_iterator(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_exhausted_iterator(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_init_before_restore(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_multiple_breaks(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_reset_restored_iterator(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_restore_in_empty_graph(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    if ds_fn2:

      def ds_fn2_no_opt():
        return ds_fn2().with_options(options)

      self.verify_restore_in_modified_graph(
          ds_fn1_no_opt,
          ds_fn2_no_opt,
          num_outputs,
          sparse_tensors=sparse_tensors)

  def verify_unused_iterator(self,
                             ds_fn,
                             num_outputs,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that saving and restoring an unused iterator works.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    # A break point at 0 saves/restores before any element is produced.
    self.verify_run_with_breaks(
        ds_fn, [0],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_fully_used_iterator(self, ds_fn, num_outputs,
                                 sparse_tensors=False):
    """Verifies that saving and restoring a fully used iterator works.

    Note that this only checks saving and restoring an iterator from which
    `num_outputs` items have been produced but does not check for an
    exhausted iterator, i.e., one from which an OutOfRange error has been
    returned.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if test fails.
    """
    self.verify_run_with_breaks(
        ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)

  def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
    """Verifies that saving and restoring an exhausted iterator works.

    An exhausted iterator is one which has returned an OutOfRange error.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if any test fails.
    """
    # Fully exhaust the iterator (including the OutOfRange) and checkpoint it.
    self.gen_outputs(
        ds_fn, [],
        num_outputs,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    # Restoring the exhausted iterator should yield no further elements.
    actual = self.gen_outputs(
        ds_fn, [],
        0,
        ckpt_saved=True,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    self.assertEqual(len(actual), 0)

  def verify_init_before_restore(self,
                                 ds_fn,
                                 num_outputs,
                                 sparse_tensors=False,
                                 verify_exhausted=True):
    """Verifies that restoring into an already initialized iterator works.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn,
        self.gen_break_points(num_outputs),
        num_outputs,
        init_before_restore=True,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_multiple_breaks(self,
                             ds_fn,
                             num_outputs,
                             num_breaks=10,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Attempts to save/restore at multiple break points.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      num_breaks: The number of break points. These are uniformly spread in
        [0, num_outputs] both inclusive.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn,
        self.gen_break_points(num_outputs, num_breaks),
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_reset_restored_iterator(self,
                                     ds_fn,
                                     num_outputs,
                                     break_point=None,
                                     sparse_tensors=False,
                                     verify_exhausted=True):
    """Attempts to re-initialize a restored iterator.

    This is useful when restoring a training checkpoint during validation.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if not break_point else break_point

    # Collect ground truth containing all outputs.
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Skip some items and save checkpoint.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Restore from checkpoint and then run init_op. Re-initializing should
    # reset the iterator to the beginning despite the restored state.
    with ops.Graph().as_default() as g:
      saver = self._import_meta_graph()
      init_op, get_next_op = self._get_iterator_ops_from_collection(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._restore(saver, sess)
        self._initialize(init_op, sess)
        for _ in range(num_outputs):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)
    self.match(expected, actual)

  def verify_restore_in_modified_graph(self,
                                       ds_fn1,
                                       ds_fn2,
                                       num_outputs,
                                       break_point=None,
                                       sparse_tensors=False,
                                       verify_exhausted=True):
    """Attempts to restore an iterator in a modified graph.

    Builds an input pipeline using ds_fn1, runs it for `break_point` steps
    and saves a checkpoint. Then builds a new graph using ds_fn2, restores
    the checkpoint from ds_fn1 and verifies that the restore is successful.

    Args:
      ds_fn1: See `run_core_tests`.
      ds_fn2: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if not break_point else break_point

    # Skip `break_point` items and store the remaining produced from ds_fn1
    # in `expected`.
    self.gen_outputs(
        ds_fn1, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)
    expected = self.gen_outputs(
        ds_fn1, [],
        num_outputs - break_point,
        ckpt_saved=True,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Generate `break_point` items from ds_fn1 and save checkpoint.
    self.gen_outputs(
        ds_fn1, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Build graph for ds_fn2 but load checkpoint for ds_fn1.
    with ops.Graph().as_default() as g:
      _, get_next_op, saver = self._build_graph(
          ds_fn2, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._restore(saver, sess)
        for _ in range(num_outputs - break_point):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

    self.match(expected, actual)

  def verify_restore_in_empty_graph(self,
                                    ds_fn,
                                    num_outputs,
                                    break_point=None,
                                    sparse_tensors=False,
                                    verify_exhausted=True):
    """Attempts to restore an iterator in an empty graph.

    Builds an input pipeline using ds_fn, runs it for `break_point` steps
    and saves a checkpoint. Then builds a new empty graph, restores
    the checkpoint from ds_fn and verifies that the restore is successful.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if not break_point else break_point

    # Skip `break_point` items and store the remaining produced from ds_fn
    # in `expected`.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs - break_point,
        ckpt_saved=True,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Generate `break_point` items from ds_fn and save checkpoint.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Build an empty graph but load checkpoint for ds_fn.
    with ops.Graph().as_default() as g:
      get_next_op, saver = self._build_empty_graph(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._restore(saver, sess)
        for _ in range(num_outputs - break_point):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

    self.match(expected, actual)

  def verify_error_on_save(self,
                           ds_fn,
                           num_outputs,
                           error,
                           break_point=None,
                           sparse_tensors=False):
    """Attempts to save a non-saveable iterator.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      error: Declared error when trying to save iterator.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if any test fails.
    """

    break_point = num_outputs // 2 if not break_point else break_point
    with ops.Graph().as_default() as g:
      init_op, get_next_op, saver = self._build_graph(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._initialize(init_op, sess)
        for _ in range(break_point):
          sess.run(get_next_op)
        with self.assertRaises(error):
          self._save(sess, saver)

  def verify_run_with_breaks(self,
                             ds_fn,
                             break_points,
                             num_outputs,
                             init_before_restore=False,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that ds_fn() produces the same outputs with and without breaks.

    1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       *without* stopping at break points.
    2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       with stopping at break points.

    Deep matches outputs from 1 and 2.

    Args:
      ds_fn: See `gen_outputs`.
      break_points: See `gen_outputs`.
      num_outputs: See `gen_outputs`.
      init_before_restore: See `gen_outputs`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        init_before_restore=init_before_restore,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    actual = self.gen_outputs(
        ds_fn,
        break_points,
        num_outputs,
        init_before_restore=init_before_restore,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    self.match(expected, actual)

  def gen_outputs(self,
                  ds_fn,
                  break_points,
                  num_outputs,
                  ckpt_saved=False,
                  init_before_restore=False,
                  sparse_tensors=False,
                  verify_exhausted=True,
                  save_checkpoint_at_end=True):
    """Generates elements from input dataset while stopping at break points.

    Produces `num_outputs` outputs and saves the state of the iterator in the
    Saver checkpoint.

    Args:
      ds_fn: 0-argument function that returns the dataset.
      break_points: A list of integers. For each `break_point` in
        `break_points`, we produce outputs till `break_point` number of items
        have been produced and then checkpoint the state. The current graph
        and session are destroyed and a new graph and session are used to
        produce outputs till next checkpoint or till `num_outputs` elements
        have been produced. `break_point` must be <= `num_outputs`.
      num_outputs: The total number of outputs to produce from the iterator.
      ckpt_saved: Whether a checkpoint already exists. If False, we build the
        graph from ds_fn.
      init_before_restore: Whether init should be called before saver.restore.
        This is just so that we can verify that restoring an already
        initialized iterator works.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.
      save_checkpoint_at_end: Whether to save a checkpoint after producing all
        outputs. If False, checkpoints are saved at each break point but not
        at the end. Note that checkpoints overwrite each other so there is
        always only a single checkpoint available. Defaults to True.

    Returns:
      A list of `num_outputs` items.
    """
    outputs = []

    def get_ops():
      # Either reload the previously saved graph or build a fresh one.
      if ckpt_saved:
        saver = self._import_meta_graph()
        init_op, get_next_op = self._get_iterator_ops_from_collection(
            ds_fn, sparse_tensors=sparse_tensors)
      else:
        init_op, get_next_op, saver = self._build_graph(
            ds_fn, sparse_tensors=sparse_tensors)
      return init_op, get_next_op, saver

    # One extra segment past the last break point runs to `num_outputs`.
    for i in range(len(break_points) + 1):
      with ops.Graph().as_default() as g:
        init_op, get_next_op, saver = get_ops()
        get_next_op = remove_variants(get_next_op)
        with self.session(graph=g) as sess:
          if ckpt_saved:
            if init_before_restore:
              self._initialize(init_op, sess)
            self._restore(saver, sess)
          else:
            self._initialize(init_op, sess)
          start = break_points[i - 1] if i > 0 else 0
          end = break_points[i] if i < len(break_points) else num_outputs
          num_iters = end - start
          for _ in range(num_iters):
            outputs.append(sess.run(get_next_op))
          if i == len(break_points) and verify_exhausted:
            with self.assertRaises(errors.OutOfRangeError):
              sess.run(get_next_op)
          if save_checkpoint_at_end or i < len(break_points):
            self._save(sess, saver)
            ckpt_saved = True

    return outputs

  def match(self, expected, actual):
    """Matches nested structures.

    Recursively matches shape and values of `expected` and `actual`.
    Handles scalars, numpy arrays and other python sequence containers
    e.g. list, dict.

    Args:
      expected: Nested structure 1.
      actual: Nested structure 2.

    Raises:
      AssertionError if matching fails.
    """
    # Normalize numpy arrays to lists so type and value comparison works
    # uniformly below.
    if isinstance(expected, np.ndarray):
      expected = expected.tolist()
    if isinstance(actual, np.ndarray):
      actual = actual.tolist()
    self.assertEqual(type(expected), type(actual))

    if nest.is_sequence(expected):
      self.assertEqual(len(expected), len(actual))
      if isinstance(expected, dict):
        for key1, key2 in zip(sorted(expected), sorted(actual)):
          self.assertEqual(key1, key2)
          self.match(expected[key1], actual[key2])
      else:
        for item1, item2 in zip(expected, actual):
          self.match(item1, item2)
    else:
      self.assertEqual(expected, actual)

  def does_not_match(self, expected, actual):
    """Asserts that `expected` and `actual` do NOT match."""
    with self.assertRaises(AssertionError):
      self.match(expected, actual)

  def gen_break_points(self, num_outputs, num_samples=10):
    """Generates `num_samples` breaks points in [0, num_outputs]."""
    return np.linspace(0, num_outputs, num_samples, dtype=int)

  def _build_graph(self, ds_fn, sparse_tensors=False):
    # Builds an initializable iterator over ds_fn() and registers its state
    # as a saveable so it is captured by the Saver.
    iterator = dataset_ops.make_initializable_iterator(ds_fn())

    saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    init_op = iterator.initializer
    if sparse_tensors:
      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    else:
      get_next = iterator.get_next()
    self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
                                         sparse_tensors)
    saver = saver_lib.Saver(allow_empty=True)
    return init_op, get_next, saver

  def _build_empty_graph(self, ds_fn, sparse_tensors=False):
    # Builds a structure-only iterator (no dataset attached) so a checkpoint
    # from ds_fn can be restored into it.
    iterator = iterator_ops.Iterator.from_structure(
        self._get_output_types(ds_fn),
        output_shapes=self._get_output_shapes(ds_fn),
        output_classes=self._get_output_classes(ds_fn))
    saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    if sparse_tensors:
      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    else:
      get_next = iterator.get_next()
    saver = saver_lib.Saver(allow_empty=True)
    return get_next, saver

  def _add_iterator_ops_to_collection(self,
                                      init_op,
                                      get_next,
                                      ds_fn,
                                      sparse_tensors=False):
    ops.add_to_collection("iterator_ops", init_op)
    # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
    # do not support tuples we flatten the tensors and restore the shape in
    # `_get_iterator_ops_from_collection`.
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      ops.add_to_collection("iterator_ops", get_next.indices)
      ops.add_to_collection("iterator_ops", get_next.values)
      ops.add_to_collection("iterator_ops", get_next.dense_shape)
      return

    get_next_list = nest.flatten(get_next)
    for i, output_class in enumerate(
        nest.flatten(self._get_output_classes(ds_fn))):
      if output_class is sparse_tensor.SparseTensor:
        ops.add_to_collection("iterator_ops", get_next_list[i].indices)
        ops.add_to_collection("iterator_ops", get_next_list[i].values)
        ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
      else:
        ops.add_to_collection("iterator_ops", get_next_list[i])

  def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
    all_ops = ops.get_collection("iterator_ops")
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      init_op, indices, values, dense_shape = all_ops
      return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
    get_next_list = []
    i = 1  # all_ops[0] is the init_op; tensors start at index 1.
    for output_class in nest.flatten(self._get_output_classes(ds_fn)):
      if output_class is sparse_tensor.SparseTensor:
        # Sparse components were flattened into three consecutive entries.
        indices, values, dense_shape = all_ops[i:i + 3]
        i += 3
        get_next_list.append(
            sparse_tensor.SparseTensor(indices, values, dense_shape))
      else:
        get_next_list.append(all_ops[i])
        i += 1
    return all_ops[0], nest.pack_sequence_as(
        self._get_output_types(ds_fn), get_next_list)

  def _get_output_types(self, ds_fn):
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_types(ds_fn())

  def _get_output_shapes(self, ds_fn):
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_shapes(ds_fn())

  def _get_output_classes(self, ds_fn):
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_classes(ds_fn())

  def _ckpt_path(self):
    return os.path.join(self.get_temp_dir(), "iterator")

  def _latest_ckpt(self):
    return checkpoint_management.latest_checkpoint(self.get_temp_dir())

  def _save(self, sess, saver):
    saver.save(sess, self._ckpt_path())

  def _restore(self, saver, sess):
    sess.run(lookup_ops.tables_initializer())
    saver.restore(sess, self._latest_ckpt())

  def _initialize(self, init_op, sess):
    sess.run(variables.global_variables_initializer())
    sess.run(lookup_ops.tables_initializer())
    sess.run(init_op)

  def _import_meta_graph(self):
    meta_file_path = self._ckpt_path() + ".meta"
    return saver_lib.import_meta_graph(meta_file_path)

  def _delete_ckpt(self):
    # Remove all checkpoint files.
    prefix = self._ckpt_path()
    pattern = prefix + "*"
    files = gfile.Glob(pattern)
    # BUG FIX: the previous `map(gfile.Remove, files)` is a no-op on
    # Python 3 because `map` is lazy and the iterator was never consumed,
    # so checkpoint files were never deleted. Iterate explicitly instead.
    for f in files:
      gfile.Remove(f)