• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for saveable_object_util."""
16
17import os
18
19from tensorflow.python.checkpoint import checkpoint
20from tensorflow.python.checkpoint import saveable_compat
21from tensorflow.python.eager import context
22from tensorflow.python.eager import test
23from tensorflow.python.framework import dtypes
24from tensorflow.python.ops import gen_resource_variable_ops
25from tensorflow.python.ops import resource_variable_ops
26from tensorflow.python.ops import variables
27from tensorflow.python.trackable import base
28from tensorflow.python.trackable import resource
29from tensorflow.python.training.saving import saveable_object
30from tensorflow.python.training.saving import saveable_object_util
31
32
class _VarSaveable(saveable_object.SaveableObject):
  """Legacy SaveableObject that checkpoints one variable under one spec.

  The variable itself is passed as the saveable's `op`, and the single
  `SaveSpec` captures the variable's current value under `name`.
  """

  def __init__(self, var, slice_spec, name):
    value_spec = saveable_object.SaveSpec(var.read_value(), slice_spec, name)
    super().__init__(var, [value_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Writes the first restored tensor back into the wrapped variable."""
    del restored_shapes  # Unused.
    first_tensor = restored_tensors[0]
    return self.op.assign(first_tensor)
41
42
class SaveableCompatibilityConverterTest(test.TestCase):
  """Tests `SaveableCompatibilityConverter` on single-spec legacy saveables."""

  def test_convert_no_saveable(self):
    trackable = base.Trackable()
    compat = saveable_object_util.SaveableCompatibilityConverter(trackable)

    # With no legacy saveables there is nothing to serialize, and an empty
    # restore succeeds.
    self.assertEmpty(compat._serialize_to_tensors())
    compat._restore_from_tensors({})

    # Any unexpected key is rejected on restore.
    with self.assertRaisesRegex(ValueError, "Could not restore object"):
      compat._restore_from_tensors({"": 0})

  def test_convert_single_saveable(self):

    class MyTrackable(base.Trackable):

      def __init__(self):
        self.a = variables.Variable(5.0)

      def _gather_saveables_for_checkpoint(self):
        return {"a": lambda name: _VarSaveable(self.a, "", name)}

    trackable = MyTrackable()
    compat = saveable_object_util.SaveableCompatibilityConverter(trackable)

    tensors = compat._serialize_to_tensors()
    self.assertLen(tensors, 1)
    self.assertIn("a", tensors)
    self.assertEqual(5, self.evaluate(tensors["a"]))

    # Restoring with a missing or unknown key fails.
    for bad_tensors in ({}, {"not_a": 1.}):
      with self.assertRaisesRegex(ValueError, "Could not restore object"):
        compat._restore_from_tensors(bad_tensors)

    self.assertEqual(5, self.evaluate(trackable.a))
    compat._restore_from_tensors({"a": 123.})
    self.assertEqual(123, self.evaluate(trackable.a))

  def test_convert_single_saveable_renamed(self):

    class MyTrackable(base.Trackable):

      def __init__(self):
        self.a = variables.Variable(15.0)

      def _gather_saveables_for_checkpoint(self):
        return {"a": lambda name: _VarSaveable(self.a, "", name + "-value")}

    trackable = MyTrackable()
    compat = saveable_object_util.SaveableCompatibilityConverter(trackable)

    # The serialized key is the SaveSpec name, not the factory key "a".
    tensors = compat._serialize_to_tensors()
    self.assertLen(tensors, 1)
    self.assertEqual(15, self.evaluate(tensors["a-value"]))

    with self.assertRaisesRegex(ValueError, "Could not restore object"):
      compat._restore_from_tensors({"a": 1.})

    self.assertEqual(15, self.evaluate(trackable.a))
    compat._restore_from_tensors({"a-value": 456.})
    self.assertEqual(456, self.evaluate(trackable.a))

  def test_convert_multiple_saveables(self):

    class MyTrackable(base.Trackable):

      def __init__(self):
        self.a = variables.Variable(15.0)
        self.b = variables.Variable(20.0)

      def _gather_saveables_for_checkpoint(self):
        return {
            "a": lambda name: _VarSaveable(self.a, "", name + "-1"),
            "b": lambda name: _VarSaveable(self.b, "", name + "-2")}

    trackable = MyTrackable()
    compat = saveable_object_util.SaveableCompatibilityConverter(trackable)

    tensors = compat._serialize_to_tensors()
    self.assertLen(tensors, 2)
    for key, expected in (("a-1", 15), ("b-2", 20)):
      self.assertEqual(expected, self.evaluate(tensors[key]))

    # Factory keys, or only a subset of the spec names, are rejected.
    for bad_tensors in ({"a": 1., "b": 2.}, {"b-2": 2.}):
      with self.assertRaisesRegex(ValueError, "Could not restore object"):
        compat._restore_from_tensors(bad_tensors)

    compat._restore_from_tensors({"a-1": -123., "b-2": -456.})
    self.assertEqual(-123, self.evaluate(trackable.a))
    self.assertEqual(-456, self.evaluate(trackable.b))

  def test_convert_variables(self):
    # `_gather_saveables_for_checkpoint` may map names directly to Variables
    # instead of to saveable factories.

    class MyTrackable(base.Trackable):

      def __init__(self):
        self.a = variables.Variable(25.)
        self.b = resource_variable_ops.UninitializedVariable(
            dtype=dtypes.float32)

      def _gather_saveables_for_checkpoint(self):
        return {"a": self.a, "b": self.b}

    trackable = MyTrackable()
    compat = saveable_object_util.SaveableCompatibilityConverter(trackable)
    tensors = compat._serialize_to_tensors()

    self.assertLen(tensors, 2)
    self.assertEqual(25, self.evaluate(tensors["a"].tensor))
    # The uninitialized variable serializes to an entry whose tensor is None.
    self.assertIsNone(tensors["b"].tensor)

    with self.assertRaisesRegex(ValueError, "Could not restore object"):
      compat._restore_from_tensors({"a": 5.})

    compat._restore_from_tensors({"a": 5., "b": 6.})
    self.assertEqual(5, self.evaluate(trackable.a))
    self.assertEqual(6, self.evaluate(trackable.b))
164
165
class _MultiSpecSaveable(saveable_object.SaveableObject):
  """Legacy SaveableObject that emits two SaveSpecs for a single object.

  `obj.a` is saved as `<name>-a` and `obj.b` as `<name>-b`; the saveable's
  `op` is `None`.
  """

  def __init__(self, obj, name):
    self.obj = obj
    sources_and_suffixes = ((obj.a, "-a"), (obj.b, "-b"))
    specs = [saveable_object.SaveSpec(source, "", name + suffix)
             for source, suffix in sources_and_suffixes]
    super().__init__(None, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns restored tensors 0 and 1 to `obj.a` and `obj.b`, in order."""
    del restored_shapes  # Unused.
    for target, tensor_index in ((self.obj.a, 0), (self.obj.b, 1)):
      target.assign(restored_tensors[tensor_index])
179
180
class MultipleSpecConverterTest(test.TestCase):
  """Tests the converter on legacy saveables that emit several SaveSpecs."""

  def test_multiple_specs_single_saveable(self):

    class MyTrackable(base.Trackable):

      def __init__(self):
        self.a = variables.Variable(35.0)
        self.b = variables.Variable(40.0)

      def _gather_saveables_for_checkpoint(self):
        return {"foo": lambda name: _MultiSpecSaveable(self, name)}

    trackable = MyTrackable()
    compat = saveable_object_util.SaveableCompatibilityConverter(trackable)
    tensors = compat._serialize_to_tensors()

    self.assertLen(tensors, 2)
    for key, expected in (("foo-a", 35), ("foo-b", 40)):
      self.assertEqual(expected, self.evaluate(tensors[key]))

    compat._restore_from_tensors({"foo-a": 5., "foo-b": 6.})
    self.assertEqual(5, self.evaluate(trackable.a))
    self.assertEqual(6, self.evaluate(trackable.b))

    # The converter must carry the legacy saveable name.
    self.assertEqual("foo", saveable_compat.get_saveable_name(compat))

  def test_multiple_specs_multiple_saveables(self):
    # Several multi-spec saveables on one object is an edge case the
    # converter does not handle, so construction is expected to raise.

    class MyTrackable(base.Trackable):

      def __init__(self):
        self.a = variables.Variable(45.0)
        self.b = variables.Variable(50.0)

      def _gather_saveables_for_checkpoint(self):
        return {"foo": lambda name: _MultiSpecSaveable(self, name),
                "bar": lambda name: _MultiSpecSaveable(self, name)}

    trackable = MyTrackable()
    with self.assertRaises(saveable_compat.CheckpointConversionError):
      saveable_object_util.SaveableCompatibilityConverter(trackable)
224
225
class State(resource.TrackableResource):
  """A scalar float32 resource variable exposed as a `TrackableResource`.

  This base class only manages the resource handle and reads/writes its
  value; subclasses define how the value is checkpointed.

  Args:
    initial_value: Value written into the resource on construction.
  """

  def __init__(self, initial_value):
    super().__init__()
    self._initial_value = initial_value
    self._initialize()

  def _create_resource(self):
    # `context.anonymous_name()` is presumably used so each instance gets its
    # own non-shared variable rather than one looked up by name.
    return gen_resource_variable_ops.var_handle_op(
        shape=[],
        dtype=dtypes.float32,
        shared_name=context.anonymous_name(),
        name="StateVar",
        container="")

  def _initialize(self):
    # Write the constructor-provided value into the freshly created resource.
    gen_resource_variable_ops.assign_variable_op(self.resource_handle,
                                                 self._initial_value)

  def _destroy_resource(self):
    # ignore_lookup_error=True: do not fail if the resource cannot be found.
    gen_resource_variable_ops.destroy_resource_op(self.resource_handle,
                                                  ignore_lookup_error=True)

  def read(self):
    """Returns the current value of the resource as a float32 tensor."""
    return gen_resource_variable_ops.read_variable_op(self.resource_handle,
                                                      dtypes.float32)

  def assign(self, value):
    """Writes `value` into the resource.

    Args:
      value: The new value to store.

    Returns:
      Whatever `assign_variable_op` returns. Restore code can forward this as
      its restore op, which graph-mode restoration expects. It may be `None`
      when executing eagerly.
    """
    return gen_resource_variable_ops.assign_variable_op(
        self.resource_handle, value)
255
256
class _StateSaveable(saveable_object.SaveableObject):
  """Legacy SaveableObject that checkpoints a `State` under a single spec.

  Args:
    obj: The `State` whose current value is captured and later restored.
    name: Name of the single SaveSpec.
  """

  def __init__(self, obj, name):
    spec = saveable_object.SaveSpec(obj.read(), "", name)
    self.obj = obj
    super().__init__(obj, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the first restored tensor back into the wrapped `State`.

    Returns:
      The result of `obj.assign`. This matches `_VarSaveable.restore`, so
      callers that expect a restore op receive whatever `assign` produced.
    """
    del restored_shapes  # Unused.
    return self.obj.assign(restored_tensors[0])
267
268
class SaveableState(State):
  """`State` checkpointed through the legacy saveable-factory API."""

  def _gather_saveables_for_checkpoint(self):
    def _make_value_saveable(name):
      return _StateSaveable(self, name)

    return {"value": _make_value_saveable}
275
276
class TrackableState(State):
  """`State` checkpointed through `_serialize_to_tensors` / `_restore_from_tensors`."""

  def _serialize_to_tensors(self):
    """Returns the current value under the single key "value"."""
    return {"value": self.read()}

  def _restore_from_tensors(self, restored_tensors):
    """Restores the value stored under "value".

    Returns:
      The result of `assign`, so callers expecting a restoration op receive
      it rather than `None`.
    """
    return self.assign(restored_tensors["value"])
286
287
class SaveableCompatibilityEndToEndTest(test.TestCase):
  """Checks that legacy-saveable and new-API objects share checkpoint data."""

  def test_checkpoint_comparison(self):
    legacy = SaveableState(5.)
    modern = TrackableState(10.)

    # Step 1: SaveableState and TrackableState should be interchangeable.
    # Write a checkpoint containing both, then read it back with the two
    # objects swapped between slots "a" and "b".
    self.assertEqual(5, self.evaluate(legacy.read()))
    self.assertEqual(10, self.evaluate(modern.read()))

    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    checkpoint.Checkpoint(a=legacy, b=modern).write(ckpt_path)

    swap_status = checkpoint.Checkpoint(b=legacy, a=modern).read(ckpt_path)
    swap_status.assert_consumed()

    self.assertEqual(10, self.evaluate(legacy.read()))
    self.assertEqual(5, self.evaluate(modern.read()))

    # Step 2: a SaveableState wrapped in the compatibility converter can read
    # the same checkpoint from either slot.
    to_convert = SaveableState(0.0)
    converted = saveable_object_util.SaveableCompatibilityConverter(
        to_convert)

    for slot, expected in (("a", 5), ("b", 10)):
      status = checkpoint.Checkpoint(**{slot: converted}).read(ckpt_path)
      status.assert_existing_objects_matched().expect_partial()
      self.assertEqual(expected, self.evaluate(to_convert.read()))
324
325
if __name__ == "__main__":
  # Run every test case in this module through TensorFlow's test runner.
  test.main()
328