# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for base_delegate."""
import os

from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.trackable import base
from tensorflow.python.trackable import base_delegate


class Inner(base.Trackable):
  """Minimal Trackable that owns a single variable dependency named "v"."""

  def __init__(self, v):
    self.v = v
    # Register the variable explicitly as a checkpoint dependency named "v".
    self._track_trackable(v, "v")


class Wrapper(base_delegate.DelegatingTrackableMixin, base.Trackable):
  """Trackable that delegates its tracking behavior to the wrapped `inner`.

  The tests below check that a checkpoint written through a `Wrapper` can be
  restored into the bare wrapped object, and that a `Wrapper` can be exported
  and re-loaded as a SavedModel.
  """

  def __init__(self, inner):
    self.inner = inner
    # The mixin receives the object it delegates to.
    super().__init__(inner)

  @property
  def v(self):
    """The variable owned by the wrapped `Inner` object."""
    return self.inner.v


@test_util.run_all_in_graph_and_eager_modes
class BaseDelegateTest(test.TestCase):

  def test_checkpoint(self):
    """Checkpoints saved through wrappers restore into the unwrapped objects."""
    a = Wrapper(Inner(variables_lib.Variable(15.0)))
    b = Wrapper(Inner(variables_lib.Variable(-15.0)))
    self.evaluate([a.v.initializer, b.v.initializer])

    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    ckpt = util.Checkpoint(a=a, b=b)
    # `Checkpoint.save` returns the path of the checkpoint it wrote.
    save_path = ckpt.save(prefix)

    self.assertEqual([15, -15], self.evaluate([a.v, b.v]))
    # Change the values so that the restore below has an observable effect.
    self.evaluate(a.v.assign(-3))
    self.evaluate(b.v.assign(12))
    self.assertEqual([-3, 12], self.evaluate([a.v, b.v]))

    # Test that the model can be saved with the wrapper and loaded without it.
    ckpt2 = util.Checkpoint(a=a.inner, b=b.inner)
    ckpt2.restore(save_path).assert_consumed().run_restore_ops()
    self.assertEqual([15, -15], self.evaluate([a.v, b.v]))

  def test_saved_model(self):
    """A wrapped object round-trips through SavedModel with its variable."""
    a = Wrapper(Inner(variables_lib.Variable(-15.0)))
    self.evaluate([a.v.initializer])
    self.assertEqual([-15], self.evaluate([a.v]))

    test_dir = self.get_temp_dir()
    saved_model_path = os.path.join(test_dir, "saved_model")
    save.save(a, saved_model_path)

    loaded = load.load(saved_model_path)
    # Initialize the restored variable before reading its value.
    self.evaluate([loaded.v.initializer])
    self.assertEqual([-15], self.evaluate([loaded.v]))


if __name__ == "__main__":
  test.main()