• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15"""Tests for the functional saver."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23from tensorflow.python.eager import context
24from tensorflow.python.eager import remote
25from tensorflow.python.eager import test
26from tensorflow.python.eager import wrap_function
27from tensorflow.python.framework import config
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import test_util
31from tensorflow.python.ops import resource_variable_ops
32from tensorflow.python.platform import gfile
33from tensorflow.python.training import server_lib
34from tensorflow.python.training.saving import checkpoint_options
35from tensorflow.python.training.saving import functional_saver
36from tensorflow.python.training.saving import saveable_hook
37from tensorflow.python.training.saving import saveable_object_util
38
# Fully-qualified device string for the local task's CPU:0. The tests pass it
# as `CheckpointOptions.experimental_io_device` and, in graph mode, check that
# the checkpoint I/O ops were placed on exactly this device.
LOCALHOST = "/job:localhost/replica:0/task:0/device:CPU:0"


class SaverTest(test.TestCase):
  """Save/restore round-trip tests for the savers in `functional_saver`."""

  def setUp(self):
    """Splits the first physical CPU and builds localhost I/O options.

    The multi-device tests place variables on `cpu:0`, `cpu:1` and `cpu:2`,
    which requires the three logical devices configured here.
    `self.local_options` asks the savers to run checkpoint I/O on LOCALHOST.
    """
    super(SaverTest, self).setUp()
    cpus = config.list_physical_devices("CPU")
    # Set 3 virtual CPUs.
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    self.local_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=LOCALHOST)

  def _assert_io_ops_on_localhost(self, op_types):
    """Asserts graph-mode ops whose type is in `op_types` are on LOCALHOST.

    Also requires at least one such op to exist, so the placement check
    cannot pass vacuously. When executing eagerly this is a no-op.

    Args:
      op_types: Tuple of op type names, e.g. ("SaveV2", "RestoreV2").
    """
    if context.executing_eagerly():
      return
    io_ops = [op for op in ops.get_default_graph().get_operations()
              if op.type in op_types]
    self.assertNotEmpty(io_ops)
    for op in io_ops:
      self.assertEqual(LOCALHOST, op.device)

  @test_util.run_in_graph_and_eager_modes
  def test_resource_variable(self):
    """Round-trips a variable through a `_SingleDeviceSaver`."""
    v1 = resource_variable_ops.ResourceVariable(2.)
    self.evaluate(v1.initializer)
    saver = functional_saver._SingleDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v1, "x"))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix)))
    self.assertLen(gfile.Glob(prefix + "*"), 2)
    self.evaluate(v1.assign(1.))
    self.evaluate(saver.restore(prefix))
    self.assertEqual(2., self.evaluate(v1))

    # A second saver over a different variable registered under the same
    # name ("x") restores the checkpointed value, overwriting 3.
    v2 = resource_variable_ops.ResourceVariable(3.)
    self.evaluate(v2.initializer)
    second_saver = functional_saver._SingleDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v2, "x"))
    self.evaluate(second_saver.restore(prefix))
    self.assertEqual(2., self.evaluate(v2))

  @test_util.run_in_graph_and_eager_modes
  def test_resource_variable_use_localhost(self):
    """Same round trip as above, with checkpoint I/O routed to LOCALHOST."""
    v1 = resource_variable_ops.ResourceVariable(2.)
    self.evaluate(v1.initializer)
    saver = functional_saver._SingleDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v1, "x"))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
    self.assertLen(gfile.Glob(prefix + "*"), 2)
    self.evaluate(v1.assign(1.))
    self.evaluate(saver.restore(prefix, self.local_options))
    self.assertEqual(2., self.evaluate(v1))

    v2 = resource_variable_ops.ResourceVariable(3.)
    self.evaluate(v2.initializer)
    second_saver = functional_saver._SingleDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v2, "x"))
    self.evaluate(second_saver.restore(prefix, self.local_options))
    self.assertEqual(2., self.evaluate(v2))

    # In graph mode, verify that the save and restore ops were set to run on
    # localhost.
    self._assert_io_ops_on_localhost(("SaveV2", "RestoreV2"))

  def test_to_proto(self):
    """Exports the saver via `to_proto()` and saves/restores through it."""
    v1 = resource_variable_ops.ResourceVariable(2.)
    saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v1, "x"))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")

    # wrap_function traces the lambda into `wrapped.graph`; the list captures
    # the proto produced during that single trace.
    proto_accumulator = []
    wrapped = wrap_function.wrap_function(
        lambda: proto_accumulator.append(saver.to_proto()), signature=())
    self.assertEqual(1, len(proto_accumulator))
    proto = proto_accumulator[0]
    # Build callable save/restore functions from the tensor and op names
    # recorded in the proto.
    save = wrapped.prune(
        feeds=wrapped.graph.get_tensor_by_name(proto.filename_tensor_name),
        fetches=wrapped.graph.get_tensor_by_name(proto.save_tensor_name))
    restore = wrapped.prune(
        feeds=wrapped.graph.get_tensor_by_name(proto.filename_tensor_name),
        fetches=wrapped.graph.get_operation_by_name(proto.restore_op_name))
    save_path = save(constant_op.constant(prefix))
    v1.assign(1.)
    restore(constant_op.constant(save_path))
    self.assertEqual(2., self.evaluate(v1))

    v2 = resource_variable_ops.ResourceVariable(3.)
    second_saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v2, "x"))
    second_saver.restore(save_path)
    self.assertEqual(2., self.evaluate(v2))

  @test_util.disable_tfrt("b/171765113: server is not supported in TFRT yet.")
  def test_checkpoint_is_sharded_by_task(self):
    """Saves variables living on three worker tasks of a local cluster.

    The test expects 4 checkpoint files, presumably one data shard per task
    plus the index file.
    """
    servers = [server_lib.Server.create_local_server() for _ in range(3)]
    cluster_spec = server_lib.ClusterSpec({
        "worker": [s.target[len("grpc://"):] for s in servers]})
    remote.connect_to_cluster(cluster_spec)
    with ops.device("/job:worker/task:0/cpu:0"):
      v0 = resource_variable_ops.ResourceVariable(0.)
    with ops.device("/job:worker/task:1/cpu:0"):
      v1 = resource_variable_ops.ResourceVariable(1.)
    with ops.device("/job:worker/task:2/cpu:0"):
      v2 = resource_variable_ops.ResourceVariable(2.)

    self.evaluate([v0.initializer, v1.initializer, v2.initializer])
    saver = functional_saver.MultiDeviceSaver(
        list(saveable_object_util.saveable_objects_for_op(v0, "v0")) +
        list(saveable_object_util.saveable_objects_for_op(v1, "v1")) +
        list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix)))
    self.assertLen(gfile.Glob(prefix + "*"), 4)
    self.evaluate(v0.assign(-1.))
    self.evaluate(v1.assign(-1.))
    self.evaluate(v2.assign(-1.))
    self.evaluate(saver.restore(constant_op.constant(prefix)))
    self.assertEqual(0., self.evaluate(v0))
    self.assertEqual(1., self.evaluate(v1))
    self.assertEqual(2., self.evaluate(v2))

  @test_util.run_in_graph_and_eager_modes
  def test_checkpoint_multi_device_using_localhost(self):
    """Saves variables on three logical CPUs with I/O routed to LOCALHOST.

    The test expects 2 checkpoint files when I/O goes through a single
    device, versus 4 in the per-task sharded test above.
    """
    with ops.device("cpu:0"):
      v0 = resource_variable_ops.ResourceVariable(0.)
    with ops.device("cpu:1"):
      v1 = resource_variable_ops.ResourceVariable(1.)
    with ops.device("cpu:2"):
      v2 = resource_variable_ops.ResourceVariable(2.)

    self.evaluate([v0.initializer, v1.initializer, v2.initializer])
    saver = functional_saver.MultiDeviceSaver(
        list(saveable_object_util.saveable_objects_for_op(v0, "v0")) +
        list(saveable_object_util.saveable_objects_for_op(v1, "v1")) +
        list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
    self.assertLen(gfile.Glob(prefix + "*"), 2)
    self.evaluate(v0.assign(-1.))
    self.evaluate(v1.assign(-1.))
    self.evaluate(v2.assign(-1.))
    self.evaluate(
        saver.restore(constant_op.constant(prefix), self.local_options))
    self.assertEqual(0., self.evaluate(v0))
    self.assertEqual(1., self.evaluate(v1))
    self.assertEqual(2., self.evaluate(v2))

    # In graph mode, verify that the save, restore and merge ops were set to
    # run on localhost.
    self._assert_io_ops_on_localhost(
        ("SaveV2", "RestoreV2", "MergeV2Checkpoints"))

  def test_callbacks_run(self):
    """Checks SaveableHook.before_save/after_restore fire once per call."""
    # Use a dict: rebinding an int counter inside the hook methods would
    # create a new local instead of updating this one (no `nonlocal` in
    # Python 2, which this file still supports via __future__ imports).
    called = {
        "save": 0,
        "restore": 0,
    }

    class DummyHook(saveable_hook.SaveableHook):

      def before_save(self):
        called["save"] += 1

      def after_restore(self):
        called["restore"] += 1

    saveable = DummyHook(name="dummy")

    saver = functional_saver.MultiDeviceSaver([saveable])
    prefix = os.path.join(self.get_temp_dir(), "ckpt")

    self.evaluate(saver.save(constant_op.constant(prefix)))
    self.assertEqual({"save": 1, "restore": 0}, called)

    self.evaluate(saver.restore(prefix))
    self.assertEqual({"save": 1, "restore": 1}, called)


if __name__ == "__main__":
  # Enable eager execution before running, so that tests without the
  # run_in_graph_and_eager_modes decorator (e.g. test_to_proto and
  # test_callbacks_run) execute eagerly.
  ops.enable_eager_execution()
  test.main()
223