1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15from __future__ import absolute_import 16from __future__ import division 17from __future__ import print_function 18 19import os 20 21from tensorflow.contrib.checkpoint.python import split_dependency 22from tensorflow.python.eager import test 23from tensorflow.python.framework import test_util 24from tensorflow.python.ops import array_ops 25from tensorflow.python.ops import resource_variable_ops 26from tensorflow.python.training.tracking import base 27from tensorflow.python.training.tracking import tracking 28from tensorflow.python.training.tracking import util 29 30 31def _split_variable_closure(variable): 32 def _fill_save_buffer_fn(save_buffer): 33 save_buffer["first_half"] = variable[:2] 34 save_buffer["second_half"] = variable[2:] 35 return _fill_save_buffer_fn 36 37 38def _combine_variable_closure(variable): 39 def _consume_restore_buffer_fn(restore_buffer): 40 return variable.assign( 41 array_ops.concat([restore_buffer["first_half"], 42 restore_buffer["second_half"]], 43 axis=0)) 44 return _consume_restore_buffer_fn 45 46 47class SaveTensorSlicesAsDeps(base.Trackable): 48 49 def __init__(self): 50 self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) 51 split_dependencies = split_dependency.split_dependency( 52 component_names=("first_half", "second_half"), 53 component_dtypes=(self.combined.dtype,) * 2, 54 fill_save_buffer_fn=_split_variable_closure( 55 self.combined), 56 consume_restore_buffer_fn=_combine_variable_closure( 57 self.combined)) 58 for name, dep in split_dependencies.items(): 59 self._track_trackable(dep, name=name) 60 61 62class HasRegularDeps(tracking.AutoTrackable): 63 64 def __init__(self): 65 self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) 66 self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) 67 68 69class OnlyOneDep(tracking.AutoTrackable): 70 71 def __init__(self): 72 self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) 73 74 75class SplitTests(test.TestCase): 76 77 @test_util.run_in_graph_and_eager_modes 78 def testSaveRestoreSplitDep(self): 79 save_checkpoint = util.Checkpoint( 80 dep=SaveTensorSlicesAsDeps()) 81 self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.])) 82 checkpoint_directory = self.get_temp_dir() 83 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 84 save_path = save_checkpoint.save(checkpoint_prefix) 85 86 regular_deps = HasRegularDeps() 87 regular_restore_checkpoint = util.Checkpoint( 88 dep=regular_deps) 89 regular_restore_checkpoint.restore( 90 save_path).assert_consumed().run_restore_ops() 91 self.assertAllEqual([1., 2.], self.evaluate(regular_deps.first_half)) 92 self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half)) 93 94 one_dep = OnlyOneDep() 95 one_dep_restore_checkpoint = util.Checkpoint(dep=one_dep) 96 status = one_dep_restore_checkpoint.restore(save_path) 97 with self.assertRaises(AssertionError): 98 # Missing the second dependency. 99 status.assert_consumed() 100 status.run_restore_ops() 101 self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half)) 102 103 restore_checkpoint = util.Checkpoint() 104 status = restore_checkpoint.restore(save_path) 105 restore_checkpoint.dep = SaveTensorSlicesAsDeps() 106 status.assert_consumed().run_restore_ops() 107 self.assertAllEqual( 108 [1., 2., 3., 4.], 109 self.evaluate(restore_checkpoint.dep.combined)) 110 111 112if __name__ == "__main__": 113 test.main() 114