1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Base class for testing serializable datasets.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23import numpy as np 24 25from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops 26from tensorflow.python.data.ops import iterator_ops 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import errors 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import sparse_tensor 31from tensorflow.python.ops import lookup_ops 32from tensorflow.python.ops import variables 33from tensorflow.python.platform import gfile 34from tensorflow.python.platform import test 35from tensorflow.python.training import saver as saver_lib 36from tensorflow.python.util import nest 37 38 39def remove_variants(get_next_op): 40 # TODO(b/72408568): Remove this once session.run can get 41 # variant tensors. 42 """Remove variants from a nest structure, so sess.run will execute.""" 43 44 def _remove_variant(x): 45 if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant: 46 return () 47 else: 48 return x 49 50 return nest.map_structure(_remove_variant, get_next_op) 51 52 53class DatasetSerializationTestBase(test.TestCase): 54 """Base class for testing serializable datasets.""" 55 56 def tearDown(self): 57 self._delete_ckpt() 58 59 # TODO(b/72657739): Remove sparse_tensor argument, which is to test the 60 # (deprecated) saveable `SparseTensorSliceDataset`, once the API 61 # `from_sparse_tensor_slices()`and related tests are deleted. 62 def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False): 63 """Runs the core tests. 64 65 Args: 66 ds_fn1: 0-argument function that returns a Dataset. 67 ds_fn2: 0-argument function that returns a Dataset different from 68 ds_fn1. If None, verify_restore_in_modified_graph test is not run. 69 num_outputs: Total number of outputs expected from this Dataset. 70 sparse_tensors: Whether dataset is built from SparseTensor(s). 71 72 Raises: 73 AssertionError if any test fails. 74 """ 75 self.verify_unused_iterator( 76 ds_fn1, num_outputs, sparse_tensors=sparse_tensors) 77 self.verify_fully_used_iterator( 78 ds_fn1, num_outputs, sparse_tensors=sparse_tensors) 79 self.verify_exhausted_iterator( 80 ds_fn1, num_outputs, sparse_tensors=sparse_tensors) 81 self.verify_init_before_restore( 82 ds_fn1, num_outputs, sparse_tensors=sparse_tensors) 83 self.verify_multiple_breaks( 84 ds_fn1, num_outputs, sparse_tensors=sparse_tensors) 85 self.verify_reset_restored_iterator( 86 ds_fn1, num_outputs, sparse_tensors=sparse_tensors) 87 self.verify_restore_in_empty_graph( 88 ds_fn1, num_outputs, sparse_tensors=sparse_tensors) 89 if ds_fn2: 90 self.verify_restore_in_modified_graph( 91 ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors) 92 93 def verify_unused_iterator(self, 94 ds_fn, 95 num_outputs, 96 sparse_tensors=False, 97 verify_exhausted=True): 98 """Verifies that saving and restoring an unused iterator works. 99 100 Args: 101 ds_fn: See `run_core_tests`. 102 num_outputs: See `run_core_tests`. 103 sparse_tensors: See `run_core_tests`. 104 verify_exhausted: See `gen_outputs`. 105 106 Raises: 107 AssertionError if any test fails. 108 """ 109 self.verify_run_with_breaks( 110 ds_fn, [0], 111 num_outputs, 112 sparse_tensors=sparse_tensors, 113 verify_exhausted=verify_exhausted) 114 115 def verify_fully_used_iterator(self, ds_fn, num_outputs, 116 sparse_tensors=False): 117 """Verifies that saving and restoring a fully used iterator works. 118 119 Note that this only checks saving and restoring an iterator from which 120 `num_outputs` items have been produced but does not check for an 121 exhausted iterator, i.e., one from which an OutOfRange error has been 122 returned. 123 124 Args: 125 ds_fn: See `run_core_tests`. 126 num_outputs: See `run_core_tests`. 127 sparse_tensors: See `run_core_tests`. 128 129 Raises: 130 AssertionError if test fails. 131 """ 132 self.verify_run_with_breaks( 133 ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors) 134 135 def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False): 136 """Verifies that saving and restoring an exhausted iterator works. 137 138 An exhausted iterator is one which has returned an OutOfRange error. 139 140 Args: 141 ds_fn: See `run_core_tests`. 142 num_outputs: See `run_core_tests`. 143 sparse_tensors: See `run_core_tests`. 144 145 Raises: 146 AssertionError if any test fails. 147 """ 148 self.gen_outputs( 149 ds_fn, [], 150 num_outputs, 151 verify_exhausted=True, 152 sparse_tensors=sparse_tensors) 153 actual = self.gen_outputs( 154 ds_fn, [], 155 0, 156 ckpt_saved=True, 157 verify_exhausted=True, 158 sparse_tensors=sparse_tensors) 159 self.assertEqual(len(actual), 0) 160 161 def verify_init_before_restore(self, 162 ds_fn, 163 num_outputs, 164 sparse_tensors=False, 165 verify_exhausted=True): 166 """Verifies that restoring into an already initilized iterator works. 167 168 Args: 169 ds_fn: See `run_core_tests`. 170 num_outputs: See `run_core_tests`. 171 sparse_tensors: See `run_core_tests`. 172 verify_exhausted: See `gen_outputs`. 173 174 Raises: 175 AssertionError if any test fails. 176 """ 177 self.verify_run_with_breaks( 178 ds_fn, 179 self.gen_break_points(num_outputs), 180 num_outputs, 181 init_before_restore=True, 182 sparse_tensors=sparse_tensors, 183 verify_exhausted=verify_exhausted) 184 185 def verify_multiple_breaks(self, 186 ds_fn, 187 num_outputs, 188 num_breaks=10, 189 sparse_tensors=False, 190 verify_exhausted=True): 191 """Attempts to save/restore at multiple break points. 192 193 Args: 194 ds_fn: See `run_core_tests`. 195 num_outputs: See `run_core_tests`. 196 num_breaks: The number of break points. These are uniformly spread in 197 [0, num_outputs] both inclusive. 198 sparse_tensors: See `run_core_tests`. 199 verify_exhausted: See `gen_outputs`. 200 201 Raises: 202 AssertionError if any test fails. 203 """ 204 self.verify_run_with_breaks( 205 ds_fn, 206 self.gen_break_points(num_outputs, num_breaks), 207 num_outputs, 208 sparse_tensors=sparse_tensors, 209 verify_exhausted=verify_exhausted) 210 211 def verify_reset_restored_iterator(self, 212 ds_fn, 213 num_outputs, 214 break_point=None, 215 sparse_tensors=False, 216 verify_exhausted=True): 217 """Attempts to re-initialize a restored iterator. 218 219 This is useful when restoring a training checkpoint during validation. 220 221 Args: 222 ds_fn: See `run_core_tests`. 223 num_outputs: See `run_core_tests`. 224 break_point: Break point. Optional. Defaults to num_outputs/2. 225 sparse_tensors: See `run_core_tests`. 226 verify_exhausted: See `gen_outputs`. 227 228 Raises: 229 AssertionError if any test fails. 230 """ 231 break_point = num_outputs // 2 if not break_point else break_point 232 233 # Collect ground truth containing all outputs. 234 expected = self.gen_outputs( 235 ds_fn, [], 236 num_outputs, 237 sparse_tensors=sparse_tensors, 238 verify_exhausted=verify_exhausted) 239 240 # Skip some items and save checkpoint. 241 self.gen_outputs( 242 ds_fn, [], 243 break_point, 244 sparse_tensors=sparse_tensors, 245 verify_exhausted=False) 246 247 actual = [] 248 # Restore from checkpoint and then run init_op. 249 with ops.Graph().as_default() as g: 250 saver = self._import_meta_graph() 251 init_op, get_next_op = self._get_iterator_ops_from_collection( 252 ds_fn, sparse_tensors=sparse_tensors) 253 get_next_op = remove_variants(get_next_op) 254 with self.test_session(graph=g) as sess: 255 self._restore(saver, sess) 256 self._initialize(init_op, sess) 257 for _ in range(num_outputs): 258 actual.append(sess.run(get_next_op)) 259 if verify_exhausted: 260 with self.assertRaises(errors.OutOfRangeError): 261 sess.run(get_next_op) 262 self.match(expected, actual) 263 264 def verify_restore_in_modified_graph(self, 265 ds_fn1, 266 ds_fn2, 267 num_outputs, 268 break_point=None, 269 sparse_tensors=False, 270 verify_exhausted=True): 271 """Attempts to restore an iterator in a modified graph. 272 273 Builds an input pipeline using ds_fn1, runs it for `break_point` steps 274 and saves a checkpoint. Then builds a new graph using ds_fn2, restores 275 the checkpoint from ds_fn1 and verifies that the restore is successful. 276 277 Args: 278 ds_fn1: See `run_core_tests`. 279 ds_fn2: See `run_core_tests`. 280 num_outputs: See `run_core_tests`. 281 break_point: Break point. Optional. Defaults to num_outputs/2. 282 sparse_tensors: See `run_core_tests`. 283 verify_exhausted: See `gen_outputs`. 284 285 Raises: 286 AssertionError if any test fails. 287 """ 288 break_point = num_outputs // 2 if not break_point else break_point 289 290 # Skip `break_point` items and store the remaining produced from ds_fn1 291 # in `expected`. 292 self.gen_outputs( 293 ds_fn1, [], 294 break_point, 295 sparse_tensors=sparse_tensors, 296 verify_exhausted=False) 297 expected = self.gen_outputs( 298 ds_fn1, [], 299 num_outputs - break_point, 300 ckpt_saved=True, 301 sparse_tensors=sparse_tensors, 302 verify_exhausted=verify_exhausted) 303 304 # Generate `break_point` items from ds_fn1 and save checkpoint. 305 self.gen_outputs( 306 ds_fn1, [], 307 break_point, 308 sparse_tensors=sparse_tensors, 309 verify_exhausted=False) 310 311 actual = [] 312 # Build graph for ds_fn2 but load checkpoint for ds_fn1. 313 with ops.Graph().as_default() as g: 314 _, get_next_op, saver = self._build_graph( 315 ds_fn2, sparse_tensors=sparse_tensors) 316 get_next_op = remove_variants(get_next_op) 317 with self.test_session(graph=g) as sess: 318 self._restore(saver, sess) 319 for _ in range(num_outputs - break_point): 320 actual.append(sess.run(get_next_op)) 321 if verify_exhausted: 322 with self.assertRaises(errors.OutOfRangeError): 323 sess.run(get_next_op) 324 325 self.match(expected, actual) 326 327 def verify_restore_in_empty_graph(self, 328 ds_fn, 329 num_outputs, 330 break_point=None, 331 sparse_tensors=False, 332 verify_exhausted=True): 333 """Attempts to restore an iterator in an empty graph. 334 335 Builds an input pipeline using ds_fn, runs it for `break_point` steps 336 and saves a checkpoint. Then builds a new empty graph, restores 337 the checkpoint from ds_fn and verifies that the restore is successful. 338 339 Args: 340 ds_fn: See `run_core_tests`. 341 num_outputs: See `run_core_tests`. 342 break_point: Break point. Optional. Defaults to num_outputs/2. 343 sparse_tensors: See `run_core_tests`. 344 verify_exhausted: See `gen_outputs`. 345 346 Raises: 347 AssertionError if any test fails. 348 """ 349 break_point = num_outputs // 2 if not break_point else break_point 350 351 # Skip `break_point` items and store the remaining produced from ds_fn 352 # in `expected`. 353 self.gen_outputs( 354 ds_fn, [], 355 break_point, 356 sparse_tensors=sparse_tensors, 357 verify_exhausted=False) 358 expected = self.gen_outputs( 359 ds_fn, [], 360 num_outputs - break_point, 361 ckpt_saved=True, 362 sparse_tensors=sparse_tensors, 363 verify_exhausted=verify_exhausted) 364 365 # Generate `break_point` items from ds_fn and save checkpoint. 366 self.gen_outputs( 367 ds_fn, [], 368 break_point, 369 sparse_tensors=sparse_tensors, 370 verify_exhausted=False) 371 372 actual = [] 373 # Build an empty graph but load checkpoint for ds_fn. 374 with ops.Graph().as_default() as g: 375 get_next_op, saver = self._build_empty_graph( 376 ds_fn, sparse_tensors=sparse_tensors) 377 get_next_op = remove_variants(get_next_op) 378 with self.test_session(graph=g) as sess: 379 self._restore(saver, sess) 380 for _ in range(num_outputs - break_point): 381 actual.append(sess.run(get_next_op)) 382 if verify_exhausted: 383 with self.assertRaises(errors.OutOfRangeError): 384 sess.run(get_next_op) 385 386 self.match(expected, actual) 387 388 def verify_error_on_save(self, 389 ds_fn, 390 num_outputs, 391 error, 392 break_point=None, 393 sparse_tensors=False): 394 """Attempts to save a non-saveable iterator. 395 396 Args: 397 ds_fn: See `run_core_tests`. 398 num_outputs: See `run_core_tests`. 399 error: Declared error when trying to save iterator. 400 break_point: Break point. Optional. Defaults to num_outputs/2. 401 sparse_tensors: See `run_core_tests`. 402 403 Raises: 404 AssertionError if any test fails. 405 """ 406 407 break_point = num_outputs // 2 if not break_point else break_point 408 with ops.Graph().as_default() as g: 409 init_op, get_next_op, saver = self._build_graph( 410 ds_fn, sparse_tensors=sparse_tensors) 411 get_next_op = remove_variants(get_next_op) 412 with self.test_session(graph=g) as sess: 413 self._initialize(init_op, sess) 414 for _ in range(break_point): 415 sess.run(get_next_op) 416 with self.assertRaises(error): 417 self._save(sess, saver) 418 419 def verify_run_with_breaks(self, 420 ds_fn, 421 break_points, 422 num_outputs, 423 init_before_restore=False, 424 sparse_tensors=False, 425 verify_exhausted=True): 426 """Verifies that ds_fn() produces the same outputs with and without breaks. 427 428 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it 429 *without* stopping at break points. 430 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it 431 with stopping at break points. 432 433 Deep matches outputs from 1 and 2. 434 435 Args: 436 ds_fn: See `gen_outputs`. 437 break_points: See `gen_outputs`. 438 num_outputs: See `gen_outputs`. 439 init_before_restore: See `gen_outputs`. 440 sparse_tensors: See `run_core_tests`. 441 verify_exhausted: See `gen_outputs`. 442 443 Raises: 444 AssertionError if any test fails. 445 """ 446 expected = self.gen_outputs( 447 ds_fn, [], 448 num_outputs, 449 init_before_restore=init_before_restore, 450 sparse_tensors=sparse_tensors, 451 verify_exhausted=verify_exhausted) 452 453 actual = self.gen_outputs( 454 ds_fn, 455 break_points, 456 num_outputs, 457 init_before_restore=init_before_restore, 458 sparse_tensors=sparse_tensors, 459 verify_exhausted=verify_exhausted) 460 461 self.match(expected, actual) 462 463 def gen_outputs(self, 464 ds_fn, 465 break_points, 466 num_outputs, 467 ckpt_saved=False, 468 init_before_restore=False, 469 sparse_tensors=False, 470 verify_exhausted=True): 471 """Generates elements from input dataset while stopping at break points. 472 473 Produces `num_outputs` outputs and saves the state of the iterator in the 474 Saver checkpoint. 475 476 Args: 477 ds_fn: 0-argument function that returns the dataset. 478 break_points: A list of integers. For each `break_point` in 479 `break_points`, we produce outputs till `break_point` number of items 480 have been produced and then checkpoint the state. The current graph 481 and session are destroyed and a new graph and session are used to 482 produce outputs till next checkpoint or till `num_outputs` elements 483 have been produced. `break_point` must be <= `num_outputs`. 484 num_outputs: The total number of outputs to produce from the iterator. 485 ckpt_saved: Whether a checkpoint already exists. If False, we build the 486 graph from ds_fn. 487 init_before_restore: Whether init should be called before saver.restore. 488 This is just so that we can verify that restoring an already initialized 489 iterator works. 490 sparse_tensors: Whether dataset is built from SparseTensor(s). 491 verify_exhausted: Whether to verify that the iterator has been exhausted 492 after producing `num_outputs` elements. 493 494 Returns: 495 A list of `num_outputs` items. 496 """ 497 outputs = [] 498 499 def get_ops(): 500 if ckpt_saved: 501 saver = self._import_meta_graph() 502 init_op, get_next_op = self._get_iterator_ops_from_collection( 503 ds_fn, sparse_tensors=sparse_tensors) 504 else: 505 init_op, get_next_op, saver = self._build_graph( 506 ds_fn, sparse_tensors=sparse_tensors) 507 get_next_op = remove_variants(get_next_op) 508 return init_op, get_next_op, saver 509 510 for i in range(len(break_points) + 1): 511 with ops.Graph().as_default() as g: 512 init_op, get_next_op, saver = get_ops() 513 get_next_op = remove_variants(get_next_op) 514 with self.test_session(graph=g) as sess: 515 if ckpt_saved: 516 if init_before_restore: 517 self._initialize(init_op, sess) 518 self._restore(saver, sess) 519 else: 520 self._initialize(init_op, sess) 521 start = break_points[i - 1] if i > 0 else 0 522 end = break_points[i] if i < len(break_points) else num_outputs 523 num_iters = end - start 524 for _ in range(num_iters): 525 outputs.append(sess.run(get_next_op)) 526 if i == len(break_points) and verify_exhausted: 527 with self.assertRaises(errors.OutOfRangeError): 528 sess.run(get_next_op) 529 self._save(sess, saver) 530 ckpt_saved = True 531 532 return outputs 533 534 def match(self, expected, actual): 535 """Matches nested structures. 536 537 Recursively matches shape and values of `expected` and `actual`. 538 Handles scalars, numpy arrays and other python sequence containers 539 e.g. list, dict. 540 541 Args: 542 expected: Nested structure 1. 543 actual: Nested structure 2. 544 545 Raises: 546 AssertionError if matching fails. 547 """ 548 if isinstance(expected, np.ndarray): 549 expected = expected.tolist() 550 if isinstance(actual, np.ndarray): 551 actual = actual.tolist() 552 self.assertEqual(type(expected), type(actual)) 553 554 if nest.is_sequence(expected): 555 self.assertEqual(len(expected), len(actual)) 556 if isinstance(expected, dict): 557 for key1, key2 in zip(sorted(expected), sorted(actual)): 558 self.assertEqual(key1, key2) 559 self.match(expected[key1], actual[key2]) 560 else: 561 for item1, item2 in zip(expected, actual): 562 self.match(item1, item2) 563 else: 564 self.assertEqual(expected, actual) 565 566 def does_not_match(self, expected, actual): 567 with self.assertRaises(AssertionError): 568 self.match(expected, actual) 569 570 def gen_break_points(self, num_outputs, num_samples=10): 571 """Generates `num_samples` breaks points in [0, num_outputs].""" 572 return np.linspace(0, num_outputs, num_samples, dtype=int) 573 574 def _build_graph(self, ds_fn, sparse_tensors=False): 575 iterator = ds_fn().make_initializable_iterator() 576 577 saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) 578 ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) 579 init_op = iterator.initializer 580 if sparse_tensors: 581 get_next = sparse_tensor.SparseTensor(*iterator.get_next()) 582 else: 583 get_next = iterator.get_next() 584 self._add_iterator_ops_to_collection(init_op, get_next, ds_fn, 585 sparse_tensors) 586 saver = saver_lib.Saver(allow_empty=True) 587 return init_op, get_next, saver 588 589 def _build_empty_graph(self, ds_fn, sparse_tensors=False): 590 iterator = iterator_ops.Iterator.from_structure( 591 self._get_output_types(ds_fn), 592 output_shapes=self._get_output_shapes(ds_fn), 593 output_classes=self._get_output_classes(ds_fn)) 594 saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) 595 ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) 596 if sparse_tensors: 597 get_next = sparse_tensor.SparseTensor(*iterator.get_next()) 598 else: 599 get_next = iterator.get_next() 600 saver = saver_lib.Saver(allow_empty=True) 601 return get_next, saver 602 603 def _add_iterator_ops_to_collection(self, 604 init_op, 605 get_next, 606 ds_fn, 607 sparse_tensors=False): 608 ops.add_to_collection("iterator_ops", init_op) 609 # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections 610 # do not support tuples we flatten the tensors and restore the shape in 611 # `_get_iterator_ops_from_collection`. 612 613 # TODO(shivaniagrwal): `output_classes` is a nested structure of classes, 614 # this base class is specific to current test cases. Update when tests are 615 # added with `output_classes` as a nested structure with at least one of the 616 # component being `tf.SparseTensor`. 617 if (sparse_tensors or 618 self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor): 619 ops.add_to_collection("iterator_ops", get_next.indices) 620 ops.add_to_collection("iterator_ops", get_next.values) 621 ops.add_to_collection("iterator_ops", get_next.dense_shape) 622 else: 623 for el in nest.flatten(get_next): 624 ops.add_to_collection("iterator_ops", el) 625 626 def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False): 627 all_ops = ops.get_collection("iterator_ops") 628 if (sparse_tensors or 629 self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor): 630 init_op, indices, values, dense_shape = all_ops 631 return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape) 632 else: 633 return all_ops[0], nest.pack_sequence_as( 634 self._get_output_types(ds_fn), all_ops[1:]) 635 636 def _get_output_types(self, ds_fn): 637 with ops.Graph().as_default(): 638 return ds_fn().output_types 639 640 def _get_output_shapes(self, ds_fn): 641 with ops.Graph().as_default(): 642 return ds_fn().output_shapes 643 644 def _get_output_classes(self, ds_fn): 645 with ops.Graph().as_default(): 646 return ds_fn().output_classes 647 648 def _ckpt_path(self): 649 return os.path.join(self.get_temp_dir(), "iterator") 650 651 def _latest_ckpt(self): 652 return saver_lib.latest_checkpoint(self.get_temp_dir()) 653 654 def _save(self, sess, saver): 655 saver.save(sess, self._ckpt_path()) 656 657 def _restore(self, saver, sess): 658 sess.run(lookup_ops.tables_initializer()) 659 saver.restore(sess, self._latest_ckpt()) 660 661 def _initialize(self, init_op, sess): 662 sess.run(variables.global_variables_initializer()) 663 sess.run(lookup_ops.tables_initializer()) 664 sess.run(init_op) 665 666 def _import_meta_graph(self): 667 meta_file_path = self._ckpt_path() + ".meta" 668 return saver_lib.import_meta_graph(meta_file_path) 669 670 def _delete_ckpt(self): 671 # Remove all checkpoint files. 672 prefix = self._ckpt_path() 673 pattern = prefix + "*" 674 files = gfile.Glob(pattern) 675 map(gfile.Remove, files) 676