• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for base_delegate."""
16import os
17from tensorflow.python.checkpoint import checkpoint as util
18from tensorflow.python.eager import test
19from tensorflow.python.framework import test_util
20from tensorflow.python.ops import variables as variables_lib
21from tensorflow.python.saved_model import load
22from tensorflow.python.saved_model import save
23from tensorflow.python.trackable import base
24from tensorflow.python.trackable import base_delegate
25
26
class Inner(base.Trackable):
  """Minimal trackable used by the tests; holds one child named "v"."""

  def __init__(self, v):
    """Stores `v` as an attribute and registers it via `_track_trackable`.

    Args:
      v: The object exposed as `self.v` and registered under the name "v".
    """
    self.v = v
    self._track_trackable(trackable=v, name="v")
32
33
class Wrapper(base_delegate.DelegatingTrackableMixin, base.Trackable):
  """Wraps another object and hands it to the delegating mixin."""

  def __init__(self, inner):
    """Keeps a reference to `inner` and forwards it to the parent initializer.

    Args:
      inner: The wrapped object; the `v` property reads `inner.v`.
    """
    self.inner = inner
    super().__init__(inner)

  @property
  def v(self):
    """The `v` attribute of the wrapped object."""
    wrapped = self.inner
    return wrapped.v
43
44
@test_util.run_all_in_graph_and_eager_modes
class BaseDelegateTest(test.TestCase):
  """Checkpoint and SavedModel round trips for `Wrapper` objects."""

  def test_checkpoint(self):
    """Saves through wrappers, then restores into the unwrapped objects."""
    first = Wrapper(Inner(variables_lib.Variable(15.0)))
    second = Wrapper(Inner(variables_lib.Variable(-15.0)))
    self.evaluate([first.v.initializer, second.v.initializer])

    def current_values():
      # Reads both variables through the wrappers' `v` property.
      return self.evaluate([first.v, second.v])

    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    wrapped_ckpt = util.Checkpoint(a=first, b=second)
    save_path = wrapped_ckpt.save(ckpt_prefix)

    self.assertEqual([15, -15], current_values())
    # Overwrite both variables so the restore below has something to undo.
    for wrapper, replacement in ((first, -3), (second, 12)):
      self.evaluate(wrapper.v.assign(replacement))
    self.assertEqual([-3, 12], current_values())

    # Test that the model can be saved with the wrapper and loaded without it.
    unwrapped_ckpt = util.Checkpoint(a=first.inner, b=second.inner)
    restore_status = unwrapped_ckpt.restore(save_path)
    restore_status.assert_consumed().run_restore_ops()
    self.assertEqual([15, -15], current_values())

  def test_saved_model(self):
    """Saves a wrapper as a SavedModel and checks `v` on the loaded object."""
    expected = [-15]
    wrapper = Wrapper(Inner(variables_lib.Variable(-15.0)))
    self.evaluate([wrapper.v.initializer])
    self.assertEqual(expected, self.evaluate([wrapper.v]))

    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(wrapper, export_dir)

    reloaded = load.load(export_dir)
    self.evaluate([reloaded.v.initializer])
    self.assertEqual(expected, self.evaluate([reloaded.v]))
80
81
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
  test.main()
84