• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15from __future__ import absolute_import
16from __future__ import division
17from __future__ import print_function
18
19import os
20
21from tensorflow.contrib.checkpoint.python import split_dependency
22from tensorflow.python.eager import test
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import resource_variable_ops
26from tensorflow.python.training.tracking import base
27from tensorflow.python.training.tracking import tracking
28from tensorflow.python.training.tracking import util
29
30
31def _split_variable_closure(variable):
32  def _fill_save_buffer_fn(save_buffer):
33    save_buffer["first_half"] = variable[:2]
34    save_buffer["second_half"] = variable[2:]
35  return _fill_save_buffer_fn
36
37
38def _combine_variable_closure(variable):
39  def _consume_restore_buffer_fn(restore_buffer):
40    return variable.assign(
41        array_ops.concat([restore_buffer["first_half"],
42                          restore_buffer["second_half"]],
43                         axis=0))
44  return _consume_restore_buffer_fn
45
46
class SaveTensorSlicesAsDeps(base.Trackable):
  """Trackable that checkpoints one variable as two separate dependencies.

  `combined` is a 4-element variable. Instead of saving it directly, the
  object tracks the two dependencies returned by
  `split_dependency.split_dependency`, named "first_half" and "second_half".
  Their callbacks write `combined[:2]` / `combined[2:]` into the save buffer
  and, on restore, concatenate the two values and assign them back into
  `combined`.
  """

  def __init__(self):
    self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.])
    component_dtype = self.combined.dtype
    # Both callbacks close over the same `self.combined` variable.
    dependencies = split_dependency.split_dependency(
        component_names=("first_half", "second_half"),
        component_dtypes=(component_dtype, component_dtype),
        fill_save_buffer_fn=_split_variable_closure(self.combined),
        consume_restore_buffer_fn=_combine_variable_closure(self.combined))
    for dependency_name, dependency in dependencies.items():
      # Track each component under its own name so that it can be matched
      # by name on restore (e.g. against HasRegularDeps' attributes).
      self._track_trackable(dependency, name=dependency_name)
60
61
class HasRegularDeps(tracking.AutoTrackable):
  """Holds the two split components as two ordinary variables.

  The attribute names match the component names used by
  `SaveTensorSlicesAsDeps`; the test restores that class's checkpoint into
  this one and expects every saved value to be consumed.
  """

  def __init__(self):
    # Both variables start as two-element zero vectors; a restore
    # overwrites them with the saved halves.
    self.first_half = resource_variable_ops.ResourceVariable([0.] * 2)
    self.second_half = resource_variable_ops.ResourceVariable([0.] * 2)
67
68
class OnlyOneDep(tracking.AutoTrackable):
  """Holds only the "first_half" variable.

  Restoring a `SaveTensorSlicesAsDeps` checkpoint into this class leaves the
  saved "second_half" value without a matching object.
  """

  def __init__(self):
    initial_value = [0.] * 2
    self.first_half = resource_variable_ops.ResourceVariable(initial_value)
73
74
class SplitTests(test.TestCase):
  """Round-trips checkpoints between split and regular dependency layouts."""

  @test_util.run_in_graph_and_eager_modes
  def testSaveRestoreSplitDep(self):
    """Saves a split variable, then restores it into three object layouts.

    The checkpoint written from `SaveTensorSlicesAsDeps` is restored into:
    (1) `HasRegularDeps`, whose two variables match both split components;
    (2) `OnlyOneDep`, which matches only the first component, so
        `assert_consumed()` must fail;
    (3) a fresh `SaveTensorSlicesAsDeps` that is attached to the checkpoint
        only after `restore()` has been called.
    """
    saver = util.Checkpoint(dep=SaveTensorSlicesAsDeps())
    self.evaluate(saver.dep.combined.assign([1., 2., 3., 4.]))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_path = saver.save(prefix)

    # Layout 1: every saved component has a matching regular variable.
    regular = HasRegularDeps()
    regular_ckpt = util.Checkpoint(dep=regular)
    regular_status = regular_ckpt.restore(save_path)
    regular_status.assert_consumed().run_restore_ops()
    expectations = (([1., 2.], regular.first_half),
                    ([3., 4.], regular.second_half))
    for expected, variable in expectations:
      self.assertAllEqual(expected, self.evaluate(variable))

    # Layout 2: "second_half" has no matching object, so the checkpoint is
    # not fully consumed, yet the matched half is still restored.
    partial = OnlyOneDep()
    partial_ckpt = util.Checkpoint(dep=partial)
    partial_status = partial_ckpt.restore(save_path)
    with self.assertRaises(AssertionError):
      partial_status.assert_consumed()
    partial_status.run_restore_ops()
    self.assertAllEqual([1., 2.], self.evaluate(partial.first_half))

    # Layout 3: `dep` is attached only after restore() was called; the
    # restored value must still land in its `combined` variable.
    deferred_ckpt = util.Checkpoint()
    deferred_status = deferred_ckpt.restore(save_path)
    deferred_ckpt.dep = SaveTensorSlicesAsDeps()
    deferred_status.assert_consumed().run_restore_ops()
    restored = self.evaluate(deferred_ckpt.dep.combined)
    self.assertAllEqual([1., 2., 3., 4.], restored)
110
111
if __name__ == "__main__":
  # Run the test cases in this module when it is executed directly.
  test.main()
114