1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Tests for the functional saver.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23from tensorflow.python.eager import context 24from tensorflow.python.eager import remote 25from tensorflow.python.eager import test 26from tensorflow.python.eager import wrap_function 27from tensorflow.python.framework import config 28from tensorflow.python.framework import constant_op 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import test_util 31from tensorflow.python.ops import resource_variable_ops 32from tensorflow.python.platform import gfile 33from tensorflow.python.training import server_lib 34from tensorflow.python.training.saving import checkpoint_options 35from tensorflow.python.training.saving import functional_saver 36from tensorflow.python.training.saving import saveable_hook 37from tensorflow.python.training.saving import saveable_object_util 38 39LOCALHOST = "/job:localhost/replica:0/task:0/device:CPU:0" 40 41 42class SaverTest(test.TestCase): 43 44 def setUp(self): 45 super(SaverTest, self).setUp() 46 cpus = config.list_physical_devices("CPU") 47 # Set 3 virtual CPUs 48 config.set_logical_device_configuration(cpus[0], [ 49 context.LogicalDeviceConfiguration(), 50 context.LogicalDeviceConfiguration(), 51 context.LogicalDeviceConfiguration() 52 ]) 53 self.local_options = checkpoint_options.CheckpointOptions( 54 experimental_io_device=LOCALHOST) 55 56 @test_util.run_in_graph_and_eager_modes 57 def test_resource_variable(self): 58 v1 = resource_variable_ops.ResourceVariable(2.) 59 self.evaluate(v1.initializer) 60 saver = functional_saver._SingleDeviceSaver( 61 saveable_object_util.saveable_objects_for_op(v1, "x")) 62 prefix = os.path.join(self.get_temp_dir(), "ckpt") 63 self.evaluate(saver.save(constant_op.constant(prefix))) 64 self.assertEqual(2, len(gfile.Glob(prefix + "*"))) 65 self.evaluate(v1.assign(1.)) 66 self.evaluate(saver.restore(prefix)) 67 self.assertEqual(2., self.evaluate(v1)) 68 69 v2 = resource_variable_ops.ResourceVariable(3.) 70 self.evaluate(v2.initializer) 71 second_saver = functional_saver._SingleDeviceSaver( 72 saveable_object_util.saveable_objects_for_op(v2, "x")) 73 self.evaluate(second_saver.restore(prefix)) 74 self.assertEqual(2., self.evaluate(v2)) 75 76 @test_util.run_in_graph_and_eager_modes 77 def test_resource_variable_use_localhost(self): 78 v1 = resource_variable_ops.ResourceVariable(2.) 79 self.evaluate(v1.initializer) 80 saver = functional_saver._SingleDeviceSaver( 81 saveable_object_util.saveable_objects_for_op(v1, "x")) 82 prefix = os.path.join(self.get_temp_dir(), "ckpt") 83 self.evaluate(saver.save(constant_op.constant(prefix), self.local_options)) 84 self.assertEqual(2, len(gfile.Glob(prefix + "*"))) 85 self.evaluate(v1.assign(1.)) 86 self.evaluate(saver.restore(prefix, self.local_options)) 87 self.assertEqual(2., self.evaluate(v1)) 88 89 v2 = resource_variable_ops.ResourceVariable(3.) 90 self.evaluate(v2.initializer) 91 second_saver = functional_saver._SingleDeviceSaver( 92 saveable_object_util.saveable_objects_for_op(v2, "x")) 93 self.evaluate(second_saver.restore(prefix, self.local_options)) 94 self.assertEqual(2., self.evaluate(v2)) 95 96 # In graph mode, verify that the save and restore ops were set to run on 97 # localhost. 98 if not context.executing_eagerly(): 99 for op in ops.get_default_graph().get_operations(): 100 if op.type in ("SaveV2", "RestoreV2"): 101 self.assertEqual(LOCALHOST, op.device) 102 103 def test_to_proto(self): 104 v1 = resource_variable_ops.ResourceVariable(2.) 105 saver = functional_saver.MultiDeviceSaver( 106 saveable_object_util.saveable_objects_for_op(v1, "x")) 107 prefix = os.path.join(self.get_temp_dir(), "ckpt") 108 109 proto_accumulator = [] 110 wrapped = wrap_function.wrap_function( 111 lambda: proto_accumulator.append(saver.to_proto()), signature=()) 112 self.assertEqual(1, len(proto_accumulator)) 113 proto = proto_accumulator[0] 114 save = wrapped.prune( 115 feeds=wrapped.graph.get_tensor_by_name(proto.filename_tensor_name), 116 fetches=wrapped.graph.get_tensor_by_name(proto.save_tensor_name)) 117 restore = wrapped.prune( 118 feeds=wrapped.graph.get_tensor_by_name(proto.filename_tensor_name), 119 fetches=wrapped.graph.get_operation_by_name(proto.restore_op_name)) 120 save_path = save(constant_op.constant(prefix)) 121 v1.assign(1.) 122 restore(constant_op.constant(save_path)) 123 self.assertEqual(2., self.evaluate(v1)) 124 125 v2 = resource_variable_ops.ResourceVariable(3.) 126 second_saver = functional_saver.MultiDeviceSaver( 127 saveable_object_util.saveable_objects_for_op(v2, "x")) 128 second_saver.restore(save_path) 129 self.assertEqual(2., self.evaluate(v2)) 130 131 @test_util.disable_tfrt("b/171765113: server is not supported in TFRT yet.") 132 def test_checkpoint_is_sharded_by_task(self): 133 servers = [server_lib.Server.create_local_server() for _ in range(3)] 134 cluster_spec = server_lib.ClusterSpec({ 135 "worker": [s.target[len("grpc://"):] for s in servers]}) 136 remote.connect_to_cluster(cluster_spec) 137 with ops.device("/job:worker/task:0/cpu:0"): 138 v0 = resource_variable_ops.ResourceVariable(0.) 139 with ops.device("/job:worker/task:1/cpu:0"): 140 v1 = resource_variable_ops.ResourceVariable(1.) 141 with ops.device("/job:worker/task:2/cpu:0"): 142 v2 = resource_variable_ops.ResourceVariable(2.) 143 144 self.evaluate([v0.initializer, v1.initializer, v2.initializer]) 145 saver = functional_saver.MultiDeviceSaver( 146 list(saveable_object_util.saveable_objects_for_op(v0, "v0")) + 147 list(saveable_object_util.saveable_objects_for_op(v1, "v1")) + 148 list(saveable_object_util.saveable_objects_for_op(v2, "v2"))) 149 prefix = os.path.join(self.get_temp_dir(), "ckpt") 150 self.evaluate(saver.save(constant_op.constant(prefix))) 151 self.assertEqual(4, len(gfile.Glob(prefix + "*"))) 152 self.evaluate(v0.assign(-1.)) 153 self.evaluate(v1.assign(-1.)) 154 self.evaluate(v2.assign(-1.)) 155 self.evaluate(saver.restore(constant_op.constant(prefix))) 156 self.assertEqual(0., self.evaluate(v0)) 157 self.assertEqual(1., self.evaluate(v1)) 158 self.assertEqual(2., self.evaluate(v2)) 159 160 @test_util.run_in_graph_and_eager_modes 161 def test_checkpoint_multi_device_using_localhost(self): 162 with ops.device("cpu:0"): 163 v0 = resource_variable_ops.ResourceVariable(0.) 164 with ops.device("cpu:1"): 165 v1 = resource_variable_ops.ResourceVariable(1.) 166 with ops.device("cpu:2"): 167 v2 = resource_variable_ops.ResourceVariable(2.) 168 169 self.evaluate([v0.initializer, v1.initializer, v2.initializer]) 170 saver = functional_saver.MultiDeviceSaver( 171 list(saveable_object_util.saveable_objects_for_op(v0, "v0")) + 172 list(saveable_object_util.saveable_objects_for_op(v1, "v1")) + 173 list(saveable_object_util.saveable_objects_for_op(v2, "v2"))) 174 prefix = os.path.join(self.get_temp_dir(), "ckpt") 175 self.evaluate(saver.save(constant_op.constant(prefix), self.local_options)) 176 self.assertEqual(2, len(gfile.Glob(prefix + "*"))) 177 self.evaluate(v0.assign(-1.)) 178 self.evaluate(v1.assign(-1.)) 179 self.evaluate(v2.assign(-1.)) 180 self.evaluate( 181 saver.restore(constant_op.constant(prefix), self.local_options)) 182 self.assertEqual(0., self.evaluate(v0)) 183 self.assertEqual(1., self.evaluate(v1)) 184 self.assertEqual(2., self.evaluate(v2)) 185 186 # In graph mode, verify that the save and restore ops were set to run on 187 # localhost. 188 if not context.executing_eagerly(): 189 for op in ops.get_default_graph().get_operations(): 190 if op.type in ("SaveV2", "RestoreV2", "MergeV2Checkpoints"): 191 self.assertEqual(LOCALHOST, op.device) 192 193 def test_callbacks_run(self): 194 # Use dict because an int would be shadowed inside callback. 195 called = { 196 "save": 0, 197 "restore": 0, 198 } 199 200 class DummyHook(saveable_hook.SaveableHook): 201 202 def before_save(self): 203 called["save"] += 1 204 205 def after_restore(self): 206 called["restore"] += 1 207 208 saveable = DummyHook(name="dummy") 209 210 saver = functional_saver.MultiDeviceSaver([saveable]) 211 prefix = os.path.join(self.get_temp_dir(), "ckpt") 212 213 self.evaluate(saver.save(constant_op.constant(prefix))) 214 self.assertEqual({"save": 1, "restore": 0}, called) 215 216 self.evaluate(saver.restore(prefix)) 217 self.assertEqual({"save": 1, "restore": 1}, called) 218 219 220if __name__ == "__main__": 221 ops.enable_eager_execution() 222 test.main() 223