1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Base test class for checkpointing datasets.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23import numpy as np 24 25from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops 26from tensorflow.python.data.ops import dataset_ops 27from tensorflow.python.data.ops import options as options_lib 28from tensorflow.python.eager import context 29from tensorflow.python.framework import combinations 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import errors 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import sparse_tensor 34from tensorflow.python.ops import lookup_ops 35from tensorflow.python.ops import variables 36from tensorflow.python.ops.ragged import ragged_tensor_value 37from tensorflow.python.platform import gfile 38from tensorflow.python.platform import test 39from tensorflow.python.training import checkpoint_management 40from tensorflow.python.training import saver as saver_lib 41from tensorflow.python.training.tracking import util as tracking_util 42from tensorflow.python.util import nest 43 44 45def remove_variants(get_next_op): 46 # TODO(b/72408568): Remove this once session.run can get variant tensors. 47 """Remove variants from a nest structure, so sess.run will execute.""" 48 49 def _remove_variant(x): 50 if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant: 51 return () 52 else: 53 return x 54 55 return nest.map_structure(_remove_variant, get_next_op) 56 57 58def default_test_combinations(): 59 """Returns the default test combinations for testing checkpointing.""" 60 61 def disable_optimizations(ds_fn): 62 options = options_lib.Options() 63 options.experimental_optimization.apply_default_optimizations = False 64 65 def ds_fn_no_opt(): 66 return ds_fn().with_options(options) 67 68 return ds_fn_no_opt 69 70 def verify_unused_iterator(obj, ds_fn, num_outputs, sparse_tensors=False): 71 obj.verify_unused_iterator( 72 ds_fn=disable_optimizations(ds_fn=ds_fn), 73 num_outputs=num_outputs, 74 sparse_tensors=sparse_tensors) 75 76 verify_unused_iterator_combination = combinations.combine( 77 verify_fn=combinations.NamedObject("verify_unused_iterator", 78 verify_unused_iterator)) 79 80 def verify_fully_used_iterator(obj, ds_fn, num_outputs, sparse_tensors=False): 81 obj.verify_fully_used_iterator( 82 ds_fn=disable_optimizations(ds_fn=ds_fn), 83 num_outputs=num_outputs, 84 sparse_tensors=sparse_tensors) 85 86 verify_fully_used_iterator_combination = combinations.combine( 87 verify_fn=combinations.NamedObject("verify_fully_used_iterator", 88 verify_fully_used_iterator)) 89 90 def verify_exhausted_iterator(obj, ds_fn, num_outputs, sparse_tensors=False): 91 obj.verify_exhausted_iterator( 92 ds_fn=disable_optimizations(ds_fn=ds_fn), 93 num_outputs=num_outputs, 94 sparse_tensors=sparse_tensors) 95 96 verify_exhausted_iterator_combination = combinations.combine( 97 verify_fn=combinations.NamedObject("verify_exhausted_iterator", 98 verify_exhausted_iterator)) 99 100 def verify_multiple_breaks(obj, ds_fn, num_outputs, sparse_tensors=False): 101 obj.verify_multiple_breaks( 102 ds_fn=disable_optimizations(ds_fn=ds_fn), 103 num_outputs=num_outputs, 104 sparse_tensors=sparse_tensors) 105 106 verify_multiple_breaks_combination = combinations.combine( 107 verify_fn=combinations.NamedObject("verify_multiple_breaks", 108 verify_multiple_breaks)) 109 110 def verify_reset_restored_iterator(obj, 111 ds_fn, 112 num_outputs, 113 sparse_tensors=False): 114 obj.verify_reset_restored_iterator( 115 ds_fn=disable_optimizations(ds_fn=ds_fn), 116 num_outputs=num_outputs, 117 sparse_tensors=sparse_tensors) 118 119 verify_reset_restored_iterator_combination = combinations.combine( 120 verify_fn=combinations.NamedObject("verify_reset_restored_iterator", 121 verify_reset_restored_iterator)) 122 123 return (verify_unused_iterator_combination + 124 verify_fully_used_iterator_combination + 125 verify_exhausted_iterator_combination + 126 verify_multiple_breaks_combination + 127 verify_reset_restored_iterator_combination) 128 129 130# TODO(b/72657739): Remove sparse_tensor argument, which is to test the 131# (deprecated) saveable `SparseTensorSliceDataset`, once the API 132# `from_sparse_tensor_slices()` and related tests are deleted. 133class CheckpointTestBase(test.TestCase): 134 """Base test class for checkpointing datasets.""" 135 136 def tearDown(self): 137 self._delete_ckpt() 138 super(CheckpointTestBase, self).tearDown() 139 140 def verify_unused_iterator(self, 141 ds_fn, 142 num_outputs, 143 sparse_tensors=False, 144 verify_exhausted=True): 145 """Verifies that saving and restoring an unused iterator works. 146 147 Args: 148 ds_fn: 0-argument function that returns a Dataset. 149 num_outputs: Total number of outputs expected from this Dataset. 150 sparse_tensors: Whether dataset is built from SparseTensor(s). 151 verify_exhausted: Whether to verify that the iterator has been exhausted 152 after producing `num_outputs` elements. 153 154 Raises: 155 AssertionError if any test fails. 156 """ 157 self.verify_run_with_breaks( 158 ds_fn, [0], 159 num_outputs, 160 sparse_tensors=sparse_tensors, 161 verify_exhausted=verify_exhausted) 162 163 def verify_fully_used_iterator(self, 164 ds_fn, 165 num_outputs, 166 sparse_tensors=False): 167 """Verifies that saving and restoring a fully used iterator works. 168 169 Note that this only checks saving and restoring an iterator from which 170 `num_outputs` items have been produced but does not check for an 171 exhausted iterator, i.e., one from which an OutOfRange error has been 172 returned. 173 174 Args: 175 ds_fn: 0-argument function that returns a Dataset. 176 num_outputs: Total number of outputs expected from this Dataset. 177 sparse_tensors: Whether dataset is built from SparseTensor(s). 178 179 Raises: 180 AssertionError if test fails. 181 """ 182 self.verify_run_with_breaks( 183 ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors) 184 185 def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False): 186 """Verifies that saving and restoring an exhausted iterator works. 187 188 An exhausted iterator is one which has returned an OutOfRange error. 189 190 Args: 191 ds_fn: 0-argument function that returns a Dataset. 192 num_outputs: Total number of outputs expected from this Dataset. 193 sparse_tensors: Whether dataset is built from SparseTensor(s). 194 195 Raises: 196 AssertionError if any test fails. 197 """ 198 self.gen_outputs( 199 ds_fn, [], 200 num_outputs, 201 verify_exhausted=True, 202 sparse_tensors=sparse_tensors) 203 actual = self.gen_outputs( 204 ds_fn, [], 205 0, 206 ckpt_saved=True, 207 verify_exhausted=True, 208 sparse_tensors=sparse_tensors) 209 self.assertLen(actual, 0) 210 211 def verify_multiple_breaks(self, 212 ds_fn, 213 num_outputs, 214 num_breaks=10, 215 sparse_tensors=False, 216 verify_exhausted=True): 217 """Attempts to save/restore at multiple break points. 218 219 Args: 220 ds_fn: 0-argument function that returns a Dataset. 221 num_outputs: Total number of outputs expected from this Dataset. 222 num_breaks: The number of break points. These are uniformly spread in [0, 223 num_outputs] both inclusive. 224 sparse_tensors: Whether dataset is built from SparseTensor(s). 225 verify_exhausted: Whether to verify that the iterator has been exhausted 226 after producing `num_outputs` elements. 227 228 Raises: 229 AssertionError if any test fails. 230 """ 231 self.verify_run_with_breaks( 232 ds_fn, 233 self.gen_break_points(num_outputs, num_breaks), 234 num_outputs, 235 sparse_tensors=sparse_tensors, 236 verify_exhausted=verify_exhausted) 237 238 def verify_reset_restored_iterator(self, 239 ds_fn, 240 num_outputs, 241 break_point=None, 242 sparse_tensors=False, 243 verify_exhausted=True): 244 """Attempts to re-initialize a restored iterator. 245 246 This is useful when restoring a training checkpoint during validation. 247 248 Args: 249 ds_fn: 0-argument function that returns a Dataset. 250 num_outputs: Total number of outputs expected from this Dataset. 251 break_point: Break point. Optional. Defaults to num_outputs/2. 252 sparse_tensors: Whether dataset is built from SparseTensor(s). 253 verify_exhausted: Whether to verify that the iterator has been exhausted 254 after producing `num_outputs` elements. 255 256 Raises: 257 AssertionError if any test fails. 258 """ 259 if context.executing_eagerly(): 260 self.skipTest("Eager mode iteration do not support re-initialization.") 261 262 break_point = num_outputs // 2 if not break_point else break_point 263 264 # Collect ground truth containing all outputs. 265 expected = self.gen_outputs( 266 ds_fn, [], 267 num_outputs, 268 sparse_tensors=sparse_tensors, 269 verify_exhausted=verify_exhausted) 270 271 # Skip some items and save checkpoint. 272 self.gen_outputs( 273 ds_fn, [], 274 break_point, 275 sparse_tensors=sparse_tensors, 276 verify_exhausted=False) 277 278 actual = [] 279 # Restore from checkpoint and then run init_op. 280 with ops.Graph().as_default() as g: 281 saver = self._import_meta_graph() 282 init_op, get_next_op = self._get_iterator_ops_from_collection( 283 ds_fn, sparse_tensors=sparse_tensors) 284 get_next_op = remove_variants(get_next_op) 285 with self.session(graph=g) as sess: 286 self._initialize(init_op, sess) 287 self._restore(saver, sess) 288 self._initialize(init_op, sess) 289 for _ in range(num_outputs): 290 actual.append(sess.run(get_next_op)) 291 if verify_exhausted: 292 with self.assertRaises(errors.OutOfRangeError): 293 sess.run(get_next_op) 294 self.match(expected, actual) 295 296 def verify_error_on_save(self, 297 ds_fn, 298 num_outputs, 299 error, 300 break_point=None, 301 sparse_tensors=False): 302 """Attempts to save a non-saveable iterator. 303 304 Args: 305 ds_fn: 0-argument function that returns a Dataset. 306 num_outputs: Total number of outputs expected from this Dataset. 307 error: Declared error when trying to save iterator. 308 break_point: Break point. Optional. Defaults to num_outputs/2. 309 sparse_tensors: Whether dataset is built from SparseTensor(s). 310 311 Raises: 312 AssertionError if any test fails. 313 """ 314 break_point = num_outputs // 2 if not break_point else break_point 315 if context.executing_eagerly(): 316 iterator = iter(ds_fn()) 317 ckpt = tracking_util.Checkpoint(iterator=iterator) 318 for _ in range(break_point): 319 next(iterator) 320 with self.assertRaises(error): 321 ckpt.save(self._ckpt_path()) 322 else: 323 with ops.Graph().as_default() as g: 324 init_op, get_next_op, saver = self._build_graph( 325 ds_fn, sparse_tensors=sparse_tensors) 326 get_next_op = remove_variants(get_next_op) 327 with self.session(graph=g) as sess: 328 self._initialize(init_op, sess) 329 for _ in range(break_point): 330 sess.run(get_next_op) 331 with self.assertRaises(error): 332 self._save(sess, saver) 333 334 def verify_run_with_breaks(self, 335 ds_fn, 336 break_points, 337 num_outputs, 338 sparse_tensors=False, 339 verify_exhausted=True): 340 """Verifies that ds_fn() produces the same outputs with and without breaks. 341 342 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it 343 *without* stopping at break points. 344 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it 345 with stopping at break points. 346 347 Deep matches outputs from 1 and 2. 348 349 Args: 350 ds_fn: 0-argument function that returns a Dataset. 351 break_points: A list of integers. For each `break_point` in 352 `break_points`, we produce outputs till `break_point` number of items 353 have been produced and then checkpoint the state. The current graph and 354 session are destroyed and a new graph and session are used to produce 355 outputs till next checkpoint or till `num_outputs` elements have been 356 produced. `break_point` must be <= `num_outputs`. 357 num_outputs: Total number of outputs expected from this Dataset. 358 sparse_tensors: Whether dataset is built from SparseTensor(s). 359 verify_exhausted: Whether to verify that the iterator has been exhausted 360 after producing `num_outputs` elements. 361 362 Raises: 363 AssertionError if any test fails. 364 """ 365 expected = self.gen_outputs( 366 ds_fn, [], 367 num_outputs, 368 sparse_tensors=sparse_tensors, 369 verify_exhausted=verify_exhausted) 370 371 actual = self.gen_outputs( 372 ds_fn, 373 break_points, 374 num_outputs, 375 sparse_tensors=sparse_tensors, 376 verify_exhausted=verify_exhausted) 377 378 self.match(expected, actual) 379 380 def gen_outputs(self, 381 ds_fn, 382 break_points, 383 num_outputs, 384 ckpt_saved=False, 385 sparse_tensors=False, 386 verify_exhausted=True, 387 save_checkpoint_at_end=True): 388 """Generates elements from input dataset while stopping at break points. 389 390 Produces `num_outputs` outputs and saves the state of the iterator in the 391 Saver checkpoint. 392 393 Args: 394 ds_fn: 0-argument function that returns the dataset. 395 break_points: A list of integers. For each `break_point` in 396 `break_points`, we produce outputs till `break_point` number of items 397 have been produced and then checkpoint the state. The current graph and 398 session are destroyed and a new graph and session are used to produce 399 outputs till next checkpoint or till `num_outputs` elements have been 400 produced. `break_point` must be <= `num_outputs`. 401 num_outputs: The total number of outputs to produce from the iterator. 402 ckpt_saved: Whether a checkpoint already exists. 403 sparse_tensors: Whether dataset is built from SparseTensor(s). 404 verify_exhausted: Whether to verify that the iterator has been exhausted 405 after producing `num_outputs` elements. 406 save_checkpoint_at_end: Whether to save a checkpoint after producing all 407 outputs. If False, checkpoints are saved each break point but not at the 408 end. Note that checkpoints overwrite each other so there is always only 409 a single checkpoint available. Defaults to True. 410 411 Returns: 412 A list of `num_outputs` items. 413 """ 414 outputs = [] 415 416 if context.executing_eagerly(): 417 for i in range(len(break_points) + 1): 418 iterator = iter(ds_fn()) 419 ckpt = tracking_util.Checkpoint(iterator=iterator) 420 if ckpt_saved: 421 ckpt_path = self._latest_ckpt() 422 ckpt.restore(ckpt_path) 423 start = break_points[i - 1] if i > 0 else 0 424 end = break_points[i] if i < len(break_points) else num_outputs 425 num_iters = end - start 426 for _ in range(num_iters): 427 outputs.append(self.evaluate(next(iterator))) 428 if i == len(break_points) and verify_exhausted: 429 with self.assertRaises(StopIteration): 430 next(iterator) 431 if save_checkpoint_at_end or i < len(break_points): 432 ckpt_path = ckpt.save(self._ckpt_path()) 433 ckpt_saved = True 434 else: 435 def get_ops(): 436 if ckpt_saved: 437 saver = self._import_meta_graph() 438 init_op, get_next_op = self._get_iterator_ops_from_collection( 439 ds_fn, sparse_tensors=sparse_tensors) 440 else: 441 init_op, get_next_op, saver = self._build_graph( 442 ds_fn, sparse_tensors=sparse_tensors) 443 return init_op, get_next_op, saver 444 445 for i in range(len(break_points) + 1): 446 with ops.Graph().as_default() as g: 447 init_op, get_next_op, saver = get_ops() 448 get_next_op = remove_variants(get_next_op) 449 with self.session(graph=g) as sess: 450 if ckpt_saved: 451 self._initialize(init_op, sess) 452 self._restore(saver, sess) 453 else: 454 self._initialize(init_op, sess) 455 start = break_points[i - 1] if i > 0 else 0 456 end = break_points[i] if i < len(break_points) else num_outputs 457 num_iters = end - start 458 for _ in range(num_iters): 459 outputs.append(sess.run(get_next_op)) 460 if i == len(break_points) and verify_exhausted: 461 with self.assertRaises(errors.OutOfRangeError): 462 sess.run(get_next_op) 463 if save_checkpoint_at_end or i < len(break_points): 464 self._save(sess, saver) 465 ckpt_saved = True 466 467 return outputs 468 469 def match(self, expected, actual): 470 """Matches nested structures. 471 472 Recursively matches shape and values of `expected` and `actual`. 473 Handles scalars, numpy arrays and other python sequence containers 474 e.g. list, dict, as well as SparseTensorValue and RaggedTensorValue. 475 476 Args: 477 expected: Nested structure 1. 478 actual: Nested structure 2. 479 480 Raises: 481 AssertionError if matching fails. 482 """ 483 if isinstance(expected, np.ndarray): 484 expected = expected.tolist() 485 if isinstance(actual, np.ndarray): 486 actual = actual.tolist() 487 self.assertEqual(type(expected), type(actual)) 488 489 if nest.is_sequence(expected): 490 self.assertEqual(len(expected), len(actual)) 491 if isinstance(expected, dict): 492 for key1, key2 in zip(sorted(expected), sorted(actual)): 493 self.assertEqual(key1, key2) 494 self.match(expected[key1], actual[key2]) 495 else: 496 for item1, item2 in zip(expected, actual): 497 self.match(item1, item2) 498 elif isinstance(expected, sparse_tensor.SparseTensorValue): 499 self.match((expected.indices, expected.values, expected.dense_shape), 500 (actual.indices, actual.values, actual.dense_shape)) 501 elif isinstance(expected, ragged_tensor_value.RaggedTensorValue): 502 self.match((expected.values, expected.row_splits), 503 (actual.values, actual.row_splits)) 504 else: 505 self.assertEqual(expected, actual) 506 507 def does_not_match(self, expected, actual): 508 with self.assertRaises(AssertionError): 509 self.match(expected, actual) 510 511 def gen_break_points(self, num_outputs, num_samples=10): 512 """Generates `num_samples` breaks points in [0, num_outputs].""" 513 return np.linspace(0, num_outputs, num_samples, dtype=int) 514 515 def _build_graph(self, ds_fn, sparse_tensors=False): 516 dataset = ds_fn() 517 iterator = dataset_ops.make_initializable_iterator(dataset) 518 external_state_policy = dataset.options().experimental_external_state_policy 519 saveable = contrib_iterator_ops.make_saveable_from_iterator( 520 iterator, external_state_policy=external_state_policy) 521 ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) 522 init_op = iterator.initializer 523 if sparse_tensors: 524 get_next = sparse_tensor.SparseTensor(*iterator.get_next()) 525 else: 526 get_next = iterator.get_next() 527 self._add_iterator_ops_to_collection(init_op, get_next, ds_fn, 528 sparse_tensors) 529 saver = saver_lib.Saver(allow_empty=True) 530 return init_op, get_next, saver 531 532 def _add_iterator_ops_to_collection(self, 533 init_op, 534 get_next, 535 ds_fn, 536 sparse_tensors=False): 537 ops.add_to_collection("iterator_ops", init_op) 538 # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections 539 # do not support tuples we flatten the tensors and restore the shape in 540 # `_get_iterator_ops_from_collection`. 541 if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`. 542 ops.add_to_collection("iterator_ops", get_next.indices) 543 ops.add_to_collection("iterator_ops", get_next.values) 544 ops.add_to_collection("iterator_ops", get_next.dense_shape) 545 return 546 547 get_next_list = nest.flatten(get_next) 548 for i, output_class in enumerate( 549 nest.flatten(self._get_output_classes(ds_fn))): 550 if output_class is sparse_tensor.SparseTensor: 551 ops.add_to_collection("iterator_ops", get_next_list[i].indices) 552 ops.add_to_collection("iterator_ops", get_next_list[i].values) 553 ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape) 554 else: 555 ops.add_to_collection("iterator_ops", get_next_list[i]) 556 557 def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False): 558 all_ops = ops.get_collection("iterator_ops") 559 if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`. 560 init_op, indices, values, dense_shape = all_ops 561 return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape) 562 get_next_list = [] 563 i = 1 564 for output_class in nest.flatten(self._get_output_classes(ds_fn)): 565 if output_class is sparse_tensor.SparseTensor: 566 indices, values, dense_shape = all_ops[i:i + 3] 567 i += 3 568 get_next_list.append( 569 sparse_tensor.SparseTensor(indices, values, dense_shape)) 570 else: 571 get_next_list.append(all_ops[i]) 572 i += 1 573 return all_ops[0], nest.pack_sequence_as( 574 self._get_output_types(ds_fn), get_next_list) 575 576 def _get_output_types(self, ds_fn): 577 assert not context.executing_eagerly() 578 with ops.Graph().as_default(): 579 return dataset_ops.get_legacy_output_types(ds_fn()) 580 581 def _get_output_shapes(self, ds_fn): 582 assert not context.executing_eagerly() 583 with ops.Graph().as_default(): 584 return dataset_ops.get_legacy_output_shapes(ds_fn()) 585 586 def _get_output_classes(self, ds_fn): 587 assert not context.executing_eagerly() 588 with ops.Graph().as_default(): 589 return dataset_ops.get_legacy_output_classes(ds_fn()) 590 591 def _ckpt_path(self): 592 return os.path.join(self.get_temp_dir(), "iterator") 593 594 def _latest_ckpt(self): 595 return checkpoint_management.latest_checkpoint(self.get_temp_dir()) 596 597 def _save(self, sess, saver): 598 saver.save(sess, self._ckpt_path()) 599 600 def _restore(self, saver, sess): 601 sess.run(lookup_ops.tables_initializer()) 602 saver.restore(sess, self._latest_ckpt()) 603 604 def _initialize(self, init_op, sess): 605 sess.run(variables.global_variables_initializer()) 606 sess.run(lookup_ops.tables_initializer()) 607 sess.run(init_op) 608 609 def _import_meta_graph(self): 610 meta_file_path = self._ckpt_path() + ".meta" 611 return saver_lib.import_meta_graph(meta_file_path) 612 613 def _delete_ckpt(self): 614 # Remove all checkpoint files. 615 prefix = self._ckpt_path() 616 pattern = prefix + "*" 617 files = gfile.Glob(pattern) 618 map(gfile.Remove, files) 619