# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base test class for checkpointing datasets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np

from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import util as tracking_util
from tensorflow.python.util import nest


def remove_variants(get_next_op):
  """Remove variants from a nest structure, so sess.run will execute."""
  # TODO(b/72408568): Remove this once session.run can get variant tensors.

  def _remove_variant(x):
    if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
      return ()
    else:
      return x

  return nest.map_structure(_remove_variant, get_next_op)
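
# Example usage (a minimal sketch, not from the original file): apply
# `remove_variants` to the result of `iterator.get_next()` before fetching it
# in graph mode, since `sess.run` cannot fetch variant tensors:
#
#   get_next_op = remove_variants(iterator.get_next())
#   values = sess.run(get_next_op)
#
# Here `iterator` and `sess` stand for a hypothetical initialized iterator
# and session; variant-dtype tensors in the structure are replaced by `()`.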


def default_test_combinations():
  """Returns the default test combinations for testing checkpointing."""

  def disable_optimizations(ds_fn):
    options = options_lib.Options()
    options.experimental_optimization.apply_default_optimizations = False

    def ds_fn_no_opt():
      return ds_fn().with_options(options)

    return ds_fn_no_opt

  def verify_unused_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
    obj.verify_unused_iterator(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_unused_iterator_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_unused_iterator",
                                         verify_unused_iterator))

  def verify_fully_used_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
    obj.verify_fully_used_iterator(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_fully_used_iterator_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_fully_used_iterator",
                                         verify_fully_used_iterator))

  def verify_exhausted_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
    obj.verify_exhausted_iterator(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_exhausted_iterator_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_exhausted_iterator",
                                         verify_exhausted_iterator))

  def verify_multiple_breaks(obj, ds_fn, num_outputs, sparse_tensors=False):
    obj.verify_multiple_breaks(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_multiple_breaks_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_multiple_breaks",
                                         verify_multiple_breaks))

  def verify_reset_restored_iterator(obj,
                                     ds_fn,
                                     num_outputs,
                                     sparse_tensors=False):
    obj.verify_reset_restored_iterator(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_reset_restored_iterator_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_reset_restored_iterator",
                                         verify_reset_restored_iterator))

  return (verify_unused_iterator_combination +
          verify_fully_used_iterator_combination +
          verify_exhausted_iterator_combination +
          verify_multiple_breaks_combination +
          verify_reset_restored_iterator_combination)
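

# Example (a hedged sketch of typical usage, not part of this module): a
# checkpoint test subclass usually crosses these verify functions with the
# standard eager/graph test modes, along the lines of
#
#   @combinations.generate(
#       combinations.times(test_base.default_test_combinations(),
#                          default_test_combinations()))
#   def testCore(self, verify_fn):
#     verify_fn(self, lambda: dataset_ops.Dataset.range(10), num_outputs=10)
#
# where `test_base` is assumed to be the dataset kernel-test base module
# imported by the caller.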


# TODO(b/72657739): Remove sparse_tensor argument, which is to test the
# (deprecated) saveable `SparseTensorSliceDataset`, once the API
# `from_sparse_tensor_slices()` and related tests are deleted.
class CheckpointTestBase(test.TestCase):
  """Base test class for checkpointing datasets."""

  def tearDown(self):
    self._delete_ckpt()
    super(CheckpointTestBase, self).tearDown()

  def verify_unused_iterator(self,
                             ds_fn,
                             num_outputs,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that saving and restoring an unused iterator works.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn, [0],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_fully_used_iterator(self,
                                 ds_fn,
                                 num_outputs,
                                 sparse_tensors=False):
    """Verifies that saving and restoring a fully used iterator works.

    Note that this only checks saving and restoring an iterator from which
    `num_outputs` items have been produced but does not check for an
    exhausted iterator, i.e., one from which an OutOfRange error has been
    returned.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).

    Raises:
      AssertionError if test fails.
    """
    self.verify_run_with_breaks(
        ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)

  def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
    """Verifies that saving and restoring an exhausted iterator works.

    An exhausted iterator is one which has returned an OutOfRange error.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).

    Raises:
      AssertionError if any test fails.
    """
    self.gen_outputs(
        ds_fn, [],
        num_outputs,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    actual = self.gen_outputs(
        ds_fn, [],
        0,
        ckpt_saved=True,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    self.assertLen(actual, 0)

  def verify_multiple_breaks(self,
                             ds_fn,
                             num_outputs,
                             num_breaks=10,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Attempts to save/restore at multiple break points.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      num_breaks: The number of break points. These are uniformly spaced in
        [0, num_outputs], both endpoints inclusive.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn,
        self.gen_break_points(num_outputs, num_breaks),
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_reset_restored_iterator(self,
                                     ds_fn,
                                     num_outputs,
                                     break_point=None,
                                     sparse_tensors=False,
                                     verify_exhausted=True):
    """Attempts to re-initialize a restored iterator.

    This is useful when restoring a training checkpoint during validation.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.

    Raises:
      AssertionError if any test fails.
    """
    if context.executing_eagerly():
      self.skipTest("Eager-mode iteration does not support re-initialization.")

    break_point = num_outputs // 2 if break_point is None else break_point

    # Collect ground truth containing all outputs.
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Skip some items and save checkpoint.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Restore from checkpoint and then run init_op.
    with ops.Graph().as_default() as g:
      saver = self._import_meta_graph()
      init_op, get_next_op = self._get_iterator_ops_from_collection(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._initialize(init_op, sess)
        self._restore(saver, sess)
        self._initialize(init_op, sess)
        for _ in range(num_outputs):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)
    self.match(expected, actual)

  def verify_error_on_save(self,
                           ds_fn,
                           num_outputs,
                           error,
                           break_point=None,
                           sparse_tensors=False):
    """Attempts to save a non-saveable iterator.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      error: Declared error when trying to save iterator.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: Whether dataset is built from SparseTensor(s).

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if break_point is None else break_point
    if context.executing_eagerly():
      iterator = iter(ds_fn())
      ckpt = tracking_util.Checkpoint(iterator=iterator)
      for _ in range(break_point):
        next(iterator)
      with self.assertRaises(error):
        ckpt.save(self._ckpt_path())
    else:
      with ops.Graph().as_default() as g:
        init_op, get_next_op, saver = self._build_graph(
            ds_fn, sparse_tensors=sparse_tensors)
        get_next_op = remove_variants(get_next_op)
        with self.session(graph=g) as sess:
          self._initialize(init_op, sess)
          for _ in range(break_point):
            sess.run(get_next_op)
          with self.assertRaises(error):
            self._save(sess, saver)
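
  # Example (a hedged sketch): a dataset holding external state that cannot
  # be serialized is expected to fail at save time, e.g.
  #
  #   self.verify_error_on_save(ds_fn, 20, errors.FailedPreconditionError)
  #
  # where FailedPreconditionError is an assumption about the error such a
  # dataset raises; the actual error depends on the dataset under test.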

  def verify_run_with_breaks(self,
                             ds_fn,
                             break_points,
                             num_outputs,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that ds_fn() produces the same outputs with and without breaks.

    1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       *without* stopping at break points.
    2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       with stopping at break points.

    Deeply matches the outputs from 1 and 2.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      break_points: A list of integers. For each `break_point` in
        `break_points`, outputs are produced until `break_point` items have
        been generated, at which point the iterator state is checkpointed.
        The current graph and session are then destroyed, and a new graph
        and session are used to produce outputs until the next break point
        or until `num_outputs` elements have been produced. Each
        `break_point` must be <= `num_outputs`.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.

    Raises:
      AssertionError if any test fails.
    """
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    actual = self.gen_outputs(
        ds_fn,
        break_points,
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    self.match(expected, actual)
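
  # Example (illustrative): verify_run_with_breaks(ds_fn, [2, 5], 10) checks
  # that producing 10 elements straight through yields the same values as
  # producing 2 elements, checkpointing and restoring in a fresh graph,
  # producing 3 more, checkpointing again, and producing the final 5.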

  def gen_outputs(self,
                  ds_fn,
                  break_points,
                  num_outputs,
                  ckpt_saved=False,
                  sparse_tensors=False,
                  verify_exhausted=True,
                  save_checkpoint_at_end=True):
    """Generates elements from input dataset while stopping at break points.

    Produces `num_outputs` outputs and saves the state of the iterator in the
    Saver checkpoint.

    Args:
      ds_fn: 0-argument function that returns the dataset.
      break_points: A list of integers. For each `break_point` in
        `break_points`, outputs are produced until `break_point` items have
        been generated, at which point the iterator state is checkpointed.
        The current graph and session are then destroyed, and a new graph
        and session are used to produce outputs until the next break point
        or until `num_outputs` elements have been produced. Each
        `break_point` must be <= `num_outputs`.
      num_outputs: The total number of outputs to produce from the iterator.
      ckpt_saved: Whether a checkpoint already exists.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.
      save_checkpoint_at_end: Whether to save a checkpoint after producing all
        outputs. If False, checkpoints are saved at each break point but not
        at the end. Note that checkpoints overwrite each other, so there is
        always only a single checkpoint available. Defaults to True.

    Returns:
      A list of `num_outputs` items.
    """
    outputs = []

    if context.executing_eagerly():
      for i in range(len(break_points) + 1):
        iterator = iter(ds_fn())
        ckpt = tracking_util.Checkpoint(iterator=iterator)
        if ckpt_saved:
          ckpt_path = self._latest_ckpt()
          ckpt.restore(ckpt_path)
        start = break_points[i - 1] if i > 0 else 0
        end = break_points[i] if i < len(break_points) else num_outputs
        num_iters = end - start
        for _ in range(num_iters):
          outputs.append(self.evaluate(next(iterator)))
        if i == len(break_points) and verify_exhausted:
          with self.assertRaises(StopIteration):
            next(iterator)
        if save_checkpoint_at_end or i < len(break_points):
          ckpt_path = ckpt.save(self._ckpt_path())
          ckpt_saved = True
    else:
      def get_ops():
        if ckpt_saved:
          saver = self._import_meta_graph()
          init_op, get_next_op = self._get_iterator_ops_from_collection(
              ds_fn, sparse_tensors=sparse_tensors)
        else:
          init_op, get_next_op, saver = self._build_graph(
              ds_fn, sparse_tensors=sparse_tensors)
        return init_op, get_next_op, saver

      for i in range(len(break_points) + 1):
        with ops.Graph().as_default() as g:
          init_op, get_next_op, saver = get_ops()
          get_next_op = remove_variants(get_next_op)
          with self.session(graph=g) as sess:
            if ckpt_saved:
              self._initialize(init_op, sess)
              self._restore(saver, sess)
            else:
              self._initialize(init_op, sess)
            start = break_points[i - 1] if i > 0 else 0
            end = break_points[i] if i < len(break_points) else num_outputs
            num_iters = end - start
            for _ in range(num_iters):
              outputs.append(sess.run(get_next_op))
            if i == len(break_points) and verify_exhausted:
              with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next_op)
            if save_checkpoint_at_end or i < len(break_points):
              self._save(sess, saver)
              ckpt_saved = True

    return outputs
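
  # Example (illustrative): gen_outputs(ds_fn, [3, 7], 10) produces elements
  # 0-2 and checkpoints, rebuilds and restores the iterator to produce
  # elements 3-6 and checkpoints again, then restores once more to produce
  # elements 7-9.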

  def match(self, expected, actual):
    """Matches nested structures.

    Recursively matches shape and values of `expected` and `actual`.
    Handles scalars, numpy arrays, and Python containers such as list and
    dict, as well as SparseTensorValue and RaggedTensorValue.

    Args:
      expected: Nested structure 1.
      actual: Nested structure 2.

    Raises:
      AssertionError if matching fails.
    """
    if isinstance(expected, np.ndarray):
      expected = expected.tolist()
    if isinstance(actual, np.ndarray):
      actual = actual.tolist()
    self.assertEqual(type(expected), type(actual))

    if nest.is_sequence(expected):
      self.assertEqual(len(expected), len(actual))
      if isinstance(expected, dict):
        for key1, key2 in zip(sorted(expected), sorted(actual)):
          self.assertEqual(key1, key2)
          self.match(expected[key1], actual[key2])
      else:
        for item1, item2 in zip(expected, actual):
          self.match(item1, item2)
    elif isinstance(expected, sparse_tensor.SparseTensorValue):
      self.match((expected.indices, expected.values, expected.dense_shape),
                 (actual.indices, actual.values, actual.dense_shape))
    elif isinstance(expected, ragged_tensor_value.RaggedTensorValue):
      self.match((expected.values, expected.row_splits),
                 (actual.values, actual.row_splits))
    else:
      self.assertEqual(expected, actual)
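
  # Example (illustrative): numpy arrays are normalized via `tolist()` before
  # comparison, so match(np.array([1, 2]), [1, 2]) passes, while
  # match([1, 2], [1, 2, 3]) raises an AssertionError on length.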

  def does_not_match(self, expected, actual):
    with self.assertRaises(AssertionError):
      self.match(expected, actual)

  def gen_break_points(self, num_outputs, num_samples=10):
    """Generates `num_samples` break points in [0, num_outputs]."""
    return np.linspace(0, num_outputs, num_samples, dtype=int)
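
  # Example: gen_break_points(20, 5) returns [0, 5, 10, 15, 20].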

  def _build_graph(self, ds_fn, sparse_tensors=False):
    dataset = ds_fn()
    iterator = dataset_ops.make_initializable_iterator(dataset)
    external_state_policy = dataset.options().experimental_external_state_policy
    saveable = contrib_iterator_ops.make_saveable_from_iterator(
        iterator, external_state_policy=external_state_policy)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    init_op = iterator.initializer
    if sparse_tensors:
      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    else:
      get_next = iterator.get_next()
    self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
                                         sparse_tensors)
    saver = saver_lib.Saver(allow_empty=True)
    return init_op, get_next, saver

  def _add_iterator_ops_to_collection(self,
                                      init_op,
                                      get_next,
                                      ds_fn,
                                      sparse_tensors=False):
    ops.add_to_collection("iterator_ops", init_op)
    # `get_next` may be a tuple e.g. in TensorSliceDataset. Since collections
    # do not support tuples we flatten the tensors and restore the shape in
    # `_get_iterator_ops_from_collection`.
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      ops.add_to_collection("iterator_ops", get_next.indices)
      ops.add_to_collection("iterator_ops", get_next.values)
      ops.add_to_collection("iterator_ops", get_next.dense_shape)
      return

    get_next_list = nest.flatten(get_next)
    for i, output_class in enumerate(
        nest.flatten(self._get_output_classes(ds_fn))):
      if output_class is sparse_tensor.SparseTensor:
        ops.add_to_collection("iterator_ops", get_next_list[i].indices)
        ops.add_to_collection("iterator_ops", get_next_list[i].values)
        ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
      else:
        ops.add_to_collection("iterator_ops", get_next_list[i])

  def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
    all_ops = ops.get_collection("iterator_ops")
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      init_op, indices, values, dense_shape = all_ops
      return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
    get_next_list = []
    i = 1
    for output_class in nest.flatten(self._get_output_classes(ds_fn)):
      if output_class is sparse_tensor.SparseTensor:
        indices, values, dense_shape = all_ops[i:i + 3]
        i += 3
        get_next_list.append(
            sparse_tensor.SparseTensor(indices, values, dense_shape))
      else:
        get_next_list.append(all_ops[i])
        i += 1
    return all_ops[0], nest.pack_sequence_as(
        self._get_output_types(ds_fn), get_next_list)
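
  # Example (illustrative): for a dataset whose elements are (dense, sparse)
  # pairs, the "iterator_ops" collection holds
  # [init_op, dense, sparse.indices, sparse.values, sparse.dense_shape]:
  # index 0 is the init op and each SparseTensor component expands to three
  # tensors, matching the flattening in `_add_iterator_ops_to_collection`.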

  def _get_output_types(self, ds_fn):
    assert not context.executing_eagerly()
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_types(ds_fn())

  def _get_output_shapes(self, ds_fn):
    assert not context.executing_eagerly()
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_shapes(ds_fn())

  def _get_output_classes(self, ds_fn):
    assert not context.executing_eagerly()
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_classes(ds_fn())

  def _ckpt_path(self):
    return os.path.join(self.get_temp_dir(), "iterator")

  def _latest_ckpt(self):
    return checkpoint_management.latest_checkpoint(self.get_temp_dir())

  def _save(self, sess, saver):
    saver.save(sess, self._ckpt_path())

  def _restore(self, saver, sess):
    sess.run(lookup_ops.tables_initializer())
    saver.restore(sess, self._latest_ckpt())

  def _initialize(self, init_op, sess):
    sess.run(variables.global_variables_initializer())
    sess.run(lookup_ops.tables_initializer())
    sess.run(init_op)

  def _import_meta_graph(self):
    meta_file_path = self._ckpt_path() + ".meta"
    return saver_lib.import_meta_graph(meta_file_path)

  def _delete_ckpt(self):
    # Remove all checkpoint files.
    prefix = self._ckpt_path()
    pattern = prefix + "*"
    files = gfile.Glob(pattern)
    # `map` is lazy in Python 3; iterate explicitly so the files are removed.
    for f in files:
      gfile.Remove(f)