• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Tests for tensorflow.python.training.saver.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import glob
22import math
23import os
24import random
25import time
26
27import numpy as np
28import six
29
30from google.protobuf.any_pb2 import Any
31
32from tensorflow.core.protobuf import config_pb2
33from tensorflow.core.protobuf import meta_graph_pb2
34from tensorflow.core.protobuf import queue_runner_pb2
35from tensorflow.core.protobuf import rewriter_config_pb2
36from tensorflow.core.protobuf import saver_pb2
37from tensorflow.python.client import session
38from tensorflow.python.data.ops import dataset_ops
39from tensorflow.python.data.ops import iterator_ops
40from tensorflow.python.eager import context
41from tensorflow.python.framework import constant_op
42from tensorflow.python.framework import dtypes
43from tensorflow.python.framework import errors
44from tensorflow.python.framework import errors_impl
45from tensorflow.python.framework import function
46from tensorflow.python.framework import graph_io
47from tensorflow.python.framework import meta_graph
48from tensorflow.python.framework import ops as ops_lib
49from tensorflow.python.framework import test_util
50from tensorflow.python.lib.io import file_io
51from tensorflow.python.ops import array_ops
52from tensorflow.python.ops import control_flow_ops
53from tensorflow.python.ops import data_flow_ops
54from tensorflow.python.ops import gradients_impl
55from tensorflow.python.ops import math_ops
56from tensorflow.python.ops import nn_ops
57from tensorflow.python.ops import partitioned_variables
58from tensorflow.python.ops import random_ops
59from tensorflow.python.ops import resource_variable_ops
60from tensorflow.python.ops import sparse_ops
61from tensorflow.python.ops import variable_scope
62from tensorflow.python.ops import variables
63import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
64from tensorflow.python.platform import gfile
65from tensorflow.python.platform import test
66from tensorflow.python.summary import summary
67from tensorflow.python.training import adam
68from tensorflow.python.training import checkpoint_management
69from tensorflow.python.training import gradient_descent
70from tensorflow.python.training import py_checkpoint_reader
71from tensorflow.python.training import queue_runner_impl
72from tensorflow.python.training import saver as saver_module
73from tensorflow.python.training import saver_test_utils
74from tensorflow.python.training.tracking import base as trackable_base
75from tensorflow.python.util import compat
76
77
78class SaverTest(test.TestCase):
79
  def basicSaveRestore(self, variable_op):
    """Round-trips a checkpoint through three fresh graphs/sessions.

    Session 1 saves two variables and a custom `CheckpointedOp` saveable;
    session 2 restores into uninitialized copies; session 3 restores over
    differently-initialized copies, verifying the checkpoint wins.

    Args:
      variable_op: Callable building a variable from `(initial_value, name)`,
        e.g. `variables.Variable` or
        `resource_variable_ops.ResourceVariable`.
    """
    save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")

    with self.session(graph=ops_lib.Graph()) as sess:
      # Build a graph with 2 parameter nodes, and Save and
      # Restore nodes for them.
      v0 = variable_op(10.0, name="v0")
      v1 = variable_op(20.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      v2_init = v2.insert("k1", 30.0)

      # Initialize all variables (in eager mode variables initialize on
      # creation, so the explicit init is skipped).
      if not context.executing_eagerly():
        self.evaluate([variables.global_variables_initializer(), v2_init])

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

      # Save the initialized values in the file at "save_path"
      save = saver_module.Saver(
          {
              "v0": v0,
              "v1": v1,
              "v2": v2.saveable
          }, restore_sequentially=True)
      val = save.save(sess, save_path)
      self.assertTrue(isinstance(val, six.string_types))
      self.assertEqual(save_path, val)

    # Start a second session.  In that session the parameter nodes
    # have not been initialized either.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variable_op(-1.0, name="v0")
      v1 = variable_op(-1.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")

      # Assert that the variables are not initialized.
      if not context.executing_eagerly():
        self.assertEqual(
            len(variables.report_uninitialized_variables().eval()), 2)
        self.assertEqual(0, len(self.evaluate(v2.keys())))
        self.assertEqual(0, len(self.evaluate(v2.values())))
      # Restore the saved values in the parameter nodes.
      save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

    # Build another graph with 2 nodes, initialized
    # differently, and a Restore node for them.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0_2 = variable_op(1000.0, name="v0")
      v1_2 = variable_op(2000.0, name="v1")
      v2_2 = saver_test_utils.CheckpointedOp(name="v2")
      v2_init = v2_2.insert("k1000", 3000.0)

      # Check that the parameter nodes have been initialized.
      if not context.executing_eagerly():
        init_all_op = [variables.global_variables_initializer(), v2_init]
        self.evaluate(init_all_op)
        # TODO(xpan): Why _mutable_hash_table_v2 doesn't create empty
        # table as it claims in eager mode?
        self.assertEqual(b"k1000", self.evaluate(v2_2.keys()))
        self.assertEqual(3000.0, self.evaluate(v2_2.values()))
      self.assertEqual(1000.0, self.evaluate(v0_2))
      self.assertEqual(2000.0, self.evaluate(v1_2))

      # Restore the values saved earlier in the parameter nodes.
      save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable})
      save2.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0_2))
      self.assertEqual(20.0, self.evaluate(v1_2))
      self.assertEqual(b"k1", self.evaluate(v2_2.keys()))
      self.assertEqual(30.0, self.evaluate(v2_2.values()))
161
162  def testBasic(self):
163    self.basicSaveRestore(variables.Variable)
164
165  @test_util.run_in_graph_and_eager_modes
166  def testResourceBasic(self):
167    self.basicSaveRestore(resource_variable_ops.ResourceVariable)
168
169  def testResourceColocation(self):
170    # train.Saver is V1 only API.
171    with ops_lib.Graph().as_default():
172      partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2)
173      with ops_lib.device("/job:ps/device:GPU:0"):
174        v = variable_scope.get_variable(
175            "v0", shape=[10, 2], partitioner=partitioner, use_resource=True)
176      saver_module.Saver({"v0": v}).build()
177      save_op = None
178      for op in ops_lib.get_default_graph().get_operations():
179        if op.type == "SaveV2":
180          save_op = op
181          break
182      assert save_op is not None
183      for save_inp in save_op.inputs[3:]:
184        # Input to SaveV2 op is placed on CPU of the same device as
185        # the Variable.
186        self.assertEqual("/job:ps/device:CPU:0", save_inp.device)
187
188  def testResourceVariableReadOpsAddedDeterministically(self):
189    graph_defs = []
190    num_graphs = 10
191    for _ in range(num_graphs):
192      with ops_lib.Graph().as_default() as g:
193        for i in range(20):
194          resource_variable_ops.ResourceVariable(i, name="var%s" % i)
195        saver_module.Saver()
196        graph_defs.append(g.as_graph_def())
197    for i in range(num_graphs - 1):
198      self.assertEqual(graph_defs[i], graph_defs[i + 1])
199
  def testEagerBasic(self):
    """Saves and restores resource variables in eager mode (session=None)."""
    with context.eager_mode():
      ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")

      v1 = resource_variable_ops.ResourceVariable(3.14, name="v1")
      v2 = resource_variable_ops.ResourceVariable([1, 2], name="v2")
      save = saver_module.Saver([v1, v2])
      # Eager save/restore takes no session, hence the None argument.
      save.save(None, ckpt_prefix)

      # Clobber the values, then restore the originals from the checkpoint.
      v1.assign(0.0)
      v2.assign([0, 0])
      self.assertNear(0.0, self.evaluate(v1), 1e-5)
      self.assertAllEqual([0, 0], self.evaluate(v2))

      save.restore(None, ckpt_prefix)
      self.assertNear(3.14, self.evaluate(v1), 1e-5)
      self.assertAllEqual([1, 2], self.evaluate(v2))
217
  def testEagerGraphCompatibility(self):
    """Checkpoints written in graph mode load in eager mode, and vice versa."""
    # Save from graph mode and restore from eager mode.
    graph_ckpt_prefix = os.path.join(self.get_temp_dir(), "graph_ckpt")
    with context.graph_mode():
      with self.session(graph=ops_lib.Graph()) as sess:
        # Create a graph model and save the checkpoint.
        w1 = resource_variable_ops.ResourceVariable(1.0, name="w1")
        w2 = resource_variable_ops.ResourceVariable(2.0, name="w2")
        graph_saver = saver_module.Saver([w1, w2])
        self.evaluate(variables.global_variables_initializer())
        graph_saver.save(sess, graph_ckpt_prefix)

    with context.eager_mode():
      # Reset global graph state so the freshly created eager variables get
      # the same names ("w1"/"w2") the checkpoint was written with.
      ops_lib._default_graph_stack.reset()  # pylint: disable=protected-access
      ops_lib.reset_default_graph()

      w1 = resource_variable_ops.ResourceVariable(0.0, name="w1")
      w2 = resource_variable_ops.ResourceVariable(0.0, name="w2")

      # Eager restore takes no session (None).
      graph_saver = saver_module.Saver([w1, w2])
      graph_saver.restore(None, graph_ckpt_prefix)

      self.assertAllEqual(self.evaluate(w1), 1.0)
      self.assertAllEqual(self.evaluate(w2), 2.0)

    # Save from eager mode and restore from graph mode.
    eager_ckpt_prefix = os.path.join(self.get_temp_dir(), "eager_ckpt")
    with context.eager_mode():
      ops_lib._default_graph_stack.reset()  # pylint: disable=protected-access
      ops_lib.reset_default_graph()

      w3 = resource_variable_ops.ResourceVariable(3.0, name="w3")
      w4 = resource_variable_ops.ResourceVariable(4.0, name="w4")

      graph_saver = saver_module.Saver([w3, w4])
      graph_saver.save(None, eager_ckpt_prefix)

    with context.graph_mode():
      with self.session(graph=ops_lib.Graph()) as sess:
        w3 = resource_variable_ops.ResourceVariable(0.0, name="w3")
        w4 = resource_variable_ops.ResourceVariable(0.0, name="w4")
        graph_saver = saver_module.Saver([w3, w4])
        self.evaluate(variables.global_variables_initializer())
        graph_saver.restore(sess, eager_ckpt_prefix)
        self.assertAllEqual(w3, 3.0)
        self.assertAllEqual(w4, 4.0)
264
  @test_util.run_in_graph_and_eager_modes
  def testResourceSaveRestoreCachingDevice(self):
    """Save/restore works for a resource variable with a caching_device."""
    save_path = os.path.join(self.get_temp_dir(), "resource_cache")
    with self.session(graph=ops_lib.Graph()) as sess:
      v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0",
                                                 name="v")
      if context.executing_eagerly():
        sess = None  # Eager save/restore takes no session.
      else:
        self.evaluate(variables.global_variables_initializer())
      save = saver_module.Saver([v])
      save.save(sess, save_path)

      # A second, independent Saver must be able to restore the same value.
      save2 = saver_module.Saver([v])
      save2.restore(sess, save_path)
      self.assertEqual(self.evaluate(v), [1])
281
282  def testNoAdditionalOpsAddedBySaverForResourceVariablesOutsideSaveScope(self):
283    with ops_lib.Graph().as_default() as g:
284      v = resource_variable_ops.ResourceVariable(1.0, name="v")
285      with ops_lib.name_scope("saver1"):
286        saver_module.Saver()
287      with ops_lib.name_scope("saver2"):
288        saver_module.Saver({"name": v})
289    ops_in_saver1_scope_but_not_save_scope = [
290        op for op in g.get_operations()
291        if (op.name.startswith("saver1/") and
292            not op.name.startswith("saver1/save/"))]
293    self.assertEqual(ops_in_saver1_scope_but_not_save_scope, [])
294    ops_in_saver2_scope_but_not_save_scope = [
295        op for op in g.get_operations()
296        if (op.name.startswith("saver2/") and
297            not op.name.startswith("saver2/save/"))]
298    self.assertEqual(ops_in_saver2_scope_but_not_save_scope, [])
299
  def testSaveCopyRestoreWithSaveRelativePaths(self):
    """Save, relocate the checkpoint dir, and restore from the new location.

    This only works for save_relative_paths=True: the CheckpointState proto
    then stores paths relative to the checkpoint directory, so the directory
    remains self-contained after being moved.
    """
    save_dir1 = os.path.join(self.get_temp_dir(), "save_dir1")
    os.mkdir(save_dir1)
    save_path1 = os.path.join(save_dir1, "save_copy_restore")

    # train.Saver is V1 only API.
    with ops_lib.Graph().as_default():
      # Build a graph with 2 parameter nodes, and Save and
      # Restore nodes for them.
      v0 = variables.VariableV1(10.0, name="v0")
      v1 = variables.VariableV1(20.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      v2_init = v2.insert("k1", 30.0)
      save = saver_module.Saver(
          var_list={
              "v0": v0,
              "v1": v1,
              "v2": v2.saveable
          },
          restore_sequentially=True,
          save_relative_paths=True)
      init_all_op = [variables.global_variables_initializer(), v2_init]

      with self.cached_session() as sess:
        # Initialize all variables
        self.evaluate(init_all_op)

        # Check that the parameter nodes have been initialized.
        self.assertEqual(10.0, self.evaluate(v0))
        self.assertEqual(20.0, self.evaluate(v1))
        self.assertEqual(b"k1", self.evaluate(v2.keys()))
        self.assertEqual(30.0, self.evaluate(v2.values()))

        # Save the initialized values in the file at "save_path"
        val = save.save(sess, save_path1)
        self.assertTrue(isinstance(val, six.string_types))
        self.assertEqual(save_path1, val)

      self.assertEqual(
          checkpoint_management.latest_checkpoint(save_dir1), save_path1)
      save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2")
      # NOTE: os.renames moves (not copies) the directory; the checkpoint
      # must still be discoverable and loadable at the new location.
      os.renames(save_dir1, save_dir2)
      save_path2 = os.path.join(save_dir2, "save_copy_restore")
      self.assertEqual(
          checkpoint_management.latest_checkpoint(save_dir2), save_path2)

      # Start a second session.  In that session the parameter nodes
      # have not been initialized either.
      with self.cached_session() as sess:
        v0 = variables.VariableV1(-1.0, name="v0")
        v1 = variables.VariableV1(-1.0, name="v1")
        v2 = saver_test_utils.CheckpointedOp(name="v2")
        save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})

        # Assert that the variables are not initialized.
        self.assertEqual(
            len(variables.report_uninitialized_variables().eval()), 2)
        self.assertEqual(0, len(self.evaluate(v2.keys())))
        self.assertEqual(0, len(self.evaluate(v2.values())))

        # Restore the saved values in the parameter nodes.
        save.restore(sess, save_path2)
        # Check that the parameter nodes have been restored.
        self.assertEqual(10.0, self.evaluate(v0))
        self.assertEqual(20.0, self.evaluate(v1))
        self.assertEqual(b"k1", self.evaluate(v2.keys()))
        self.assertEqual(30.0, self.evaluate(v2.values()))
371
372  def testFilenameTensor(self):
373    # train.Saver is V1 only API.
374    with ops_lib.Graph().as_default():
375      v0 = variables.VariableV1(0, name="v0")
376      filename = b"somerandomfilename"
377      save = saver_module.Saver({"v0": v0}, filename=filename)
378      with self.cached_session() as sess:
379        tensor = sess.graph.get_tensor_by_name(
380            save.saver_def.filename_tensor_name)
381        self.assertEqual(self.evaluate(tensor), filename)
382
383  def testInvalidPath(self):
384    v0 = variables.VariableV1(0, name="v0")
385    for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
386      with self.cached_session() as sess:
387        save = saver_module.Saver({"v0": v0}, write_version=ver)
388        with self.assertRaisesRegex(
389            ValueError, "The passed save_path is not a valid checkpoint:"):
390          save.restore(sess, "invalid path")
391
  @test_util.run_v1_only("train.Saver is V1 only API.")
  def testInt64(self):
    """Saves and restores an int64 scalar variable."""
    save_path = os.path.join(self.get_temp_dir(), "int64")

    with self.cached_session() as sess:
      # Build a graph with 1 node, and save and restore for them.
      v = variables.VariableV1(np.int64(15), name="v")
      save = saver_module.Saver({"v": v}, restore_sequentially=True)
      self.evaluate(variables.global_variables_initializer())

      # Save the initialized values in the file at "save_path"
      val = save.save(sess, save_path)
      self.assertTrue(isinstance(val, six.string_types))
      self.assertEqual(save_path, val)

      # NOTE(review): this with-block is nested inside the outer session and
      # rebinds `v`/`save` while the code that uses them sits outside it —
      # presumably cached_session() hands back the same session so this
      # works, but confirm the nesting is intentional and not an
      # indentation slip.
      with self.cached_session() as sess:
        v = variables.VariableV1(np.int64(-1), name="v")
        save = saver_module.Saver({"v": v})

      # The rebuilt variable is uninitialized until restore runs.
      with self.assertRaisesWithPredicateMatch(
          errors_impl.OpError, lambda e: "uninitialized value v" in e.message):
        self.evaluate(v)

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(np.int64(15), self.evaluate(v))
419
  def testSomeErrors(self):
    """Duplicate checkpoint names among saveables raise ValueError."""
    with ops_lib.Graph().as_default():
      v0 = variables.VariableV1([10.0], name="v0")
      v1 = variables.VariableV1([20.0], name="v1")
      v2 = variables.VariableV1([20.0], name="v2")
      # Force v2 to be saved under the slice name "v1" via the private
      # save-slice API, manufacturing a collision with the real v1.
      v2._set_save_slice_info(
          variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))

      # By default the name used for "v2" will be "v1" and raise an error.
      with self.assertRaisesRegex(ValueError, "same name: v1"):
        saver_module.Saver([v0, v1, v2])

      # The names are different and will work.
      saver_module.Saver({"vee1": v1, "other": [v2]})

      # Partitioned variables also cause name conflicts.
      p_v1 = variable_scope.get_variable(
          "p_v1",
          shape=[4, 5],
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=2))
      p_v2 = variable_scope.get_variable(
          "p_v2",
          shape=[4, 5],
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=2))
      # Rename via the private attribute to collide with p_v1.
      p_v2._name = "p_v1"
      with self.assertRaisesRegex(ValueError, "same name: p_v1"):
        saver_module.Saver([p_v1, p_v2])
449
450  def testSameName(self):
451    with ops_lib.Graph().as_default():
452      v0 = variables.VariableV1([10.0], name="v0")
453      v2 = saver_test_utils.CheckpointedOp(name="v2")
454
455      # Saving one variable under two names raises an error.
456      with self.assertRaisesRegex(
457          ValueError, "The same saveable will be restored with two names: v0"):
458        saver_module.Saver({"v0": v0, "v0too": v0})
459
460      # Ditto for custom saveables.
461      with self.assertRaisesRegex(
462          ValueError, "The same saveable will be restored with two names: v2"):
463        saver_module.Saver({"v2": v2.saveable, "v2too": v2.saveable})
464
465      # Verify non-duplicate names work.
466      saver_module.Saver({"v0": v0, "v2": v2.saveable})
467
  @test_util.run_v1_only("train.Saver and VariableV1 are V1 only APIs.")
  def testBasicsWithListOfVariables(self):
    """Same round-trip as basicSaveRestore, but var_list is a plain list.

    With a list (instead of a dict) the saveables are keyed by their own
    variable names, so fresh graphs that reuse the names restore correctly.
    """
    save_path = os.path.join(self.get_temp_dir(), "basics_with_list")

    with self.session(graph=ops_lib.Graph()) as sess:
      # Build a graph with 2 parameter nodes, and Save and
      # Restore nodes for them.
      v0 = variables.VariableV1(10.0, name="v0")
      v1 = variables.VariableV1(20.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      v2_init = v2.insert("k1", 30.0)
      save = saver_module.Saver([v0, v1, v2.saveable])
      self.evaluate(variables.global_variables_initializer())
      v2_init.run()

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

      # Save the initialized values in the file at "save_path"
      val = save.save(sess, save_path)
      self.assertTrue(isinstance(val, six.string_types))
      self.assertEqual(save_path, val)

    # Start a second session.  In that session the variables
    # have not been initialized either.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variables.VariableV1(-1.0, name="v0")
      v1 = variables.VariableV1(-1.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      save = saver_module.Saver([v0, v1, v2.saveable])

      with self.assertRaisesWithPredicateMatch(
          errors_impl.OpError, lambda e: "uninitialized value v0" in e.message):
        self.evaluate(v0)
      with self.assertRaisesWithPredicateMatch(
          errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):
        self.evaluate(v1)
      self.assertEqual(0, len(self.evaluate(v2.keys())))
      self.assertEqual(0, len(self.evaluate(v2.values())))

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

    # Build another graph with 2 nodes, initialized
    # differently, and a Restore node for them.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0_2 = variables.VariableV1(1000.0, name="v0")
      v1_2 = variables.VariableV1(2000.0, name="v1")
      v2_2 = saver_test_utils.CheckpointedOp(name="v2")
      save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable])
      v2_2.insert("k1000", 3000.0).run()
      self.evaluate(variables.global_variables_initializer())

      # Check that the parameter nodes have been initialized.
      self.assertEqual(1000.0, self.evaluate(v0_2))
      self.assertEqual(2000.0, self.evaluate(v1_2))
      self.assertEqual(b"k1000", self.evaluate(v2_2.keys()))
      self.assertEqual(3000.0, self.evaluate(v2_2.values()))
      # Restore the values saved earlier in the parameter nodes.
      save2.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0_2))
      self.assertEqual(20.0, self.evaluate(v1_2))
      self.assertEqual(b"k1", self.evaluate(v2_2.keys()))
      self.assertEqual(30.0, self.evaluate(v2_2.values()))
541
542  def _SaveAndLoad(self, var_name, var_value, other_value, save_path):
543    with self.session(graph=ops_lib.Graph()) as sess:
544      var = resource_variable_ops.ResourceVariable(var_value, name=var_name)
545      save = saver_module.Saver({var_name: var})
546      if not context.executing_eagerly():
547        self.evaluate(var.initializer)
548      val = save.save(sess, save_path)
549      self.assertEqual(save_path, val)
550    with self.session(graph=ops_lib.Graph()) as sess:
551      var = resource_variable_ops.ResourceVariable(other_value, name=var_name)
552      save = saver_module.Saver({var_name: var})
553      save.restore(sess, save_path)
554      self.assertAllClose(var_value, self.evaluate(var))
555
556  def testCacheRereadsFile(self):
557    save_path = os.path.join(self.get_temp_dir(), "cache_rereads")
558    # Save and reload one Variable named "var0".
559    self._SaveAndLoad("var0", 0.0, 1.0, save_path)
560    # Save and reload one Variable named "var1" in the same file.
561    # The cached readers should know to re-read the file.
562    self._SaveAndLoad("var1", 1.1, 2.2, save_path)
563
564  def testAllowEmpty(self):
565    save_path = os.path.join(self.get_temp_dir(), "allow_empty")
566    # train.Saver is V1 only API.
567    with ops_lib.Graph().as_default(), self.cached_session() as sess:
568      _ = constant_op.constant(1)
569      save = saver_module.Saver(allow_empty=True)
570      val = save.save(sess, save_path)
571      self.assertIsNone(val)
572    with ops_lib.Graph().as_default(), self.cached_session() as sess:
573      save = saver_module.Saver(allow_empty=True)
574      save.restore(sess, save_path)
575
  def testGPU(self):
    """Saves a GPU-placed variable; skipped when no GPU is available."""
    if not test.is_gpu_available():
      return
    save_path = os.path.join(self.get_temp_dir(), "gpu")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      with sess.graph.device(test.gpu_device_name()):
        v0_1 = variables.VariableV1(123.45)
      save = saver_module.Saver({"v0": v0_1})
      self.evaluate(variables.global_variables_initializer())
      save.save(sess, save_path)

    with session.Session("", graph=ops_lib.Graph()) as sess:
      with sess.graph.device(test.gpu_device_name()):
        v0_2 = variables.VariableV1(543.21)
      save = saver_module.Saver({"v0": v0_2})
      self.evaluate(variables.global_variables_initializer())
      # NOTE(review): the second session builds a Saver but never calls
      # restore(), so only graph construction with GPU-placed variables is
      # exercised here — confirm the missing restore is intentional.
592
  def testSharedServerOnGPU(self):
    """Sharded save with a GPU-placed variable; skipped without a GPU."""
    if not test.is_gpu_available():
      return
    save_path = os.path.join(self.get_temp_dir(), "gpu")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      with sess.graph.device(test.gpu_device_name()):
        v0_1 = variables.VariableV1(123.45)
      save = saver_module.Saver({"v0": v0_1}, sharded=True, allow_empty=True)
      self.evaluate(variables.global_variables_initializer())
      save.save(sess, save_path)

    with session.Session("", graph=ops_lib.Graph()) as sess:
      with sess.graph.device(test.gpu_device_name()):
        v0_2 = variables.VariableV1(543.21)
      save = saver_module.Saver({"v0": v0_2}, sharded=True, allow_empty=True)
      self.evaluate(variables.global_variables_initializer())
      # NOTE(review): as in testGPU, the second Saver is built but restore()
      # is never called — confirm this construction-only check is intended.
609
  def testVariables(self):
    """A no-argument Saver defaults to saving/restoring all variables."""
    save_path = os.path.join(self.get_temp_dir(), "variables")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      one = variables.VariableV1(1.0)
      twos = variables.VariableV1([2.0, 2.0, 2.0])
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      init = variables.global_variables_initializer()
      save = saver_module.Saver()
      init.run()
      v2.insert("k1", 3.0).run()
      save.save(sess, save_path)

    with session.Session("", graph=ops_lib.Graph()) as sess:
      one = variables.VariableV1(0.0)
      twos = variables.VariableV1([0.0, 0.0, 0.0])
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      # Saver with no arg, defaults to 'all variables'.
      save = saver_module.Saver()
      save.restore(sess, save_path)
      self.assertAllClose(1.0, self.evaluate(one))
      self.assertAllClose([2.0, 2.0, 2.0], self.evaluate(twos))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(3.0, self.evaluate(v2.values()))
633
634  def testVarListShouldBeEmptyInDeferredBuild(self):
635    with ops_lib.Graph().as_default():
636      v = variables.VariableV1(1.0)
637      with self.assertRaisesRegex(ValueError, "defer_build"):
638        saver_module.Saver([v], defer_build=True)
639
640  def testBuildShouldBeCalledBeforeSaveInCaseOfDeferBuild(self):
641    save_path = os.path.join(self.get_temp_dir(), "error_deferred_build")
642    with ops_lib.Graph().as_default(), session.Session() as sess:
643      variables.VariableV1(1.0)
644      saver = saver_module.Saver(defer_build=True)
645      with self.assertRaisesRegex(RuntimeError, "build"):
646        saver.save(sess, save_path)
647
  def testDeferredBuild(self):
    """defer_build lets a Saver pick up variables created after it."""
    save_path = os.path.join(self.get_temp_dir(), "deferred_build")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      one = variables.VariableV1(1.0)
      save = saver_module.Saver(defer_build=True)
      # if build is not deferred, saver cannot save the `twos`.
      twos = variables.VariableV1([2.0, 2.0, 2.0])
      init = variables.global_variables_initializer()
      save.build()
      init.run()
      save.save(sess, save_path)

    with session.Session("", graph=ops_lib.Graph()) as sess:
      one = variables.VariableV1(0.0)
      twos = variables.VariableV1([0.0, 0.0, 0.0])
      # Saver with no arg, defaults to 'all variables'.
      save = saver_module.Saver()
      save.restore(sess, save_path)
      # Both variables restore, proving `twos` made it into the checkpoint.
      self.assertAllClose(1.0, self.evaluate(one))
      self.assertAllClose([2.0, 2.0, 2.0], self.evaluate(twos))
668
  @test_util.run_v1_only("train.Saver is V1 only API.")
  def testReshape(self):
    """reshape=True allows restoring into a variable of a different shape."""
    save_path = os.path.join(self.get_temp_dir(), "variables_reshape")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      # Save a (2, 3) variable.
      var = variables.VariableV1([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
      init = variables.global_variables_initializer()
      save = saver_module.Saver()
      init.run()
      save.save(sess, save_path)

    # Error when restoring with default reshape=False
    with session.Session("", graph=ops_lib.Graph()) as sess:
      var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
      save = saver_module.Saver()
      with self.assertRaisesRegex(
          errors_impl.InvalidArgumentError,
          "Assign requires shapes of both tensors to match."):
        save.restore(sess, save_path)

    # Restored to new shape with reshape=True
    with session.Session("", graph=ops_lib.Graph()) as sess:
      var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
      save = saver_module.Saver(reshape=True)
      save.restore(sess, save_path)
      # The six saved elements reflow row-major into the (3, 2) variable.
      self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                          self.evaluate(var))
695
  @test_util.run_in_graph_and_eager_modes
  def testSaveWithGlobalStep(self, pad_step_number=False):
    """Checkpoint paths get a "-<step>" suffix from global_step.

    Args:
      pad_step_number: If True, the step suffix is zero-padded to 8 digits.
    """
    save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step")
    global_step_int = 5
    # Save and reload one Variable named "var0".
    self._SaveAndLoad("var0", 0.0, 1.0, save_path)
    for use_tensor in [True, False]:
      with self.session(graph=ops_lib.Graph()):
        var = resource_variable_ops.ResourceVariable(1.0, name="var0")
        save = saver_module.Saver(
            {
                var._shared_name: var
            }, pad_step_number=pad_step_number)
        if context.executing_eagerly():
          sess = None
        else:
          self.evaluate(var.initializer)
          sess = ops_lib.get_default_session()
        # global_step may be either a Python int or a scalar tensor; both
        # must produce the same suffixed path.
        if use_tensor:
          global_step = constant_op.constant(global_step_int)
          val = save.save(sess, save_path, global_step=global_step)
        else:
          val = save.save(sess, save_path, global_step=global_step_int)
        if pad_step_number:
          expected_save_path = "%s-%s" % (save_path,
                                          "{:08d}".format(global_step_int))
        else:
          expected_save_path = "%s-%d" % (save_path, global_step_int)
        self.assertEqual(expected_save_path, val)
725
  def testSaveWithGlobalStepWithPadding(self):
    """Same as testSaveWithGlobalStep, but with zero-padded step suffixes."""
    self.testSaveWithGlobalStep(pad_step_number=True)
728
729  def testSaveToNonexistingPath(self):
730    file_io.write_string_to_file(
731        os.path.join(self.get_temp_dir(), "actually_a_file"), "")
732    paths = [
733        os.path.join(self.get_temp_dir(), "nonexisting_dir/path"),
734        os.path.join(self.get_temp_dir(), "other_nonexisting_dir/path1/path2"),
735        os.path.join(self.get_temp_dir(), "actually_a_file/path"),
736    ]
737
738    for save_path in paths:
739      # Build a graph with 2 parameter nodes, and Save and
740      # Restore nodes for them.
741      v0 = variables.VariableV1(10.0, name="v0")
742      v1 = variables.VariableV1(20.0, name="v1")
743      save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
744      init_all_op = variables.global_variables_initializer()
745
746      # In the case where the parent directory doesn't exist, whether or not the
747      # save succeeds or fails is implementation dependent.  Therefore we allow
748      # both cases.
749      try:
750        with self.cached_session() as sess:
751          # Initialize all variables
752          self.evaluate(init_all_op)
753
754          # Check that the parameter nodes have been initialized.
755          self.assertEqual(10.0, self.evaluate(v0))
756          self.assertEqual(20.0, self.evaluate(v1))
757
758          # Save the graph.
759          save.save(sess, save_path)
760
761        with self.cached_session() as sess:
762          # Restore the saved values in the parameter nodes.
763          save.restore(sess, save_path)
764          # Check that the parameter nodes have been restored.
765          self.assertEqual(10.0, self.evaluate(v0))
766          self.assertEqual(20.0, self.evaluate(v1))
767      except ValueError as exc:
768        error_msg_template = "Parent directory of {} doesn't exist, can't save."
769        self.assertEqual(error_msg_template.format(save_path), str(exc))
770
771  def testSaveToURI(self):
772    # ParseURI functions don't work on Windows yet.
773    # TODO(jhseu): Remove this check when it works.
774    if os.name == "nt":
775      self.skipTest("Local URI support doesn't work on Windows")
776    save_path = "file://" + os.path.join(self.get_temp_dir(), "uri")
777
778    # Build a graph with 2 parameter nodes, and Save and
779    # Restore nodes for them.
780    v0 = variables.VariableV1(10.0, name="v0")
781    v1 = variables.VariableV1(20.0, name="v1")
782    save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
783    init_all_op = variables.global_variables_initializer()
784
785    with self.cached_session() as sess:
786      # Initialize all variables
787      self.evaluate(init_all_op)
788
789      # Check that the parameter nodes have been initialized.
790      self.assertEqual(10.0, self.evaluate(v0))
791      self.assertEqual(20.0, self.evaluate(v1))
792      save.save(sess, save_path)
793
794  def testSaveRestoreAndValidateVariableDtype(self):
795    for variable_op in [
796        variables.Variable, resource_variable_ops.ResourceVariable
797    ]:
798      save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")
799
800      # Build the first session.
801      with self.session(graph=ops_lib.Graph()) as sess:
802        v0 = variable_op(10.0, name="v0", dtype=dtypes.float32)
803
804        if not context.executing_eagerly():
805          self.evaluate([variables.global_variables_initializer()])
806
807        save = saver_module.Saver({"v0": v0})
808        save.save(sess, save_path)
809
810      # Start a second session.
811      with self.session(graph=ops_lib.Graph()) as sess:
812        v0_wrong_dtype = variable_op(1, name="v0", dtype=dtypes.int32)
813        # Restore the saved value with different dtype
814        # in the parameter nodes.
815        save = saver_module.Saver({"v0": v0_wrong_dtype})
816        with self.assertRaisesRegex(errors.InvalidArgumentError,
817                                    "original dtype"):
818          save.restore(sess, save_path)
819
820  # Test restoring large tensors (triggers a thread pool)
821  def testRestoreLargeTensors(self):
822    save_dir = self.get_temp_dir()
823    def _model():
824      small_v = [variable_scope.get_variable(
825          "small%d" % i, shape=[10, 2], use_resource=True) for i in range(5)]
826      large_v = [variable_scope.get_variable(
827          "large%d" % i, shape=[32000, 1000], use_resource=True)
828                 for i in range(3)]
829      return small_v + large_v
830
831    save_graph = ops_lib.Graph()
832    with save_graph.as_default(), self.session(graph=save_graph) as sess:
833      orig_vars = _model()
834      self.evaluate(variables.global_variables_initializer())
835      save = saver_module.Saver(max_to_keep=1)
836      self.evaluate(variables.global_variables_initializer())
837      save.save(sess, save_dir)
838      orig_vals = self.evaluate(orig_vars)
839
840    restore_graph = ops_lib.Graph()
841    with restore_graph.as_default(), self.session(
842        graph=restore_graph) as sess:
843      restored_vars = _model()
844      save = saver_module.Saver(max_to_keep=1)
845      save.restore(sess, save_dir)
846      restored_vals = self.evaluate(restored_vars)
847
848    for orig, restored in zip(orig_vals, restored_vals):
849      self.assertAllEqual(orig, restored)
850
851
class SaveRestoreShardedTest(test.TestCase):
  """Tests sharded (per-device) checkpoint save and restore with V1 format."""

  # Checkpoint format under test; the V2 subclass overrides this.
  _WRITE_VERSION = saver_pb2.SaverDef.V1

  def _get_test_dir(self, dirname):
    """Creates and returns a test subdirectory under the temp dir."""
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def testBasics(self):
    """Saves variables/CheckpointedOps on two devices, restores per shard."""
    save_path = os.path.join(self.get_temp_dir(), "sharded_basics")

    # Build a graph with 2 parameter nodes on different devices.
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        v0 = variables.VariableV1(10, name="v0")
        t0 = saver_test_utils.CheckpointedOp(name="t0")
      with sess.graph.device("/cpu:1"):
        v1 = variables.VariableV1(20, name="v1")
        t1 = saver_test_utils.CheckpointedOp(name="t1")
      save = saver_module.Saver(
          {
              "v0": v0,
              "v1": v1,
              "t0": t0.saveable,
              "t1": t1.saveable
          },
          write_version=self._WRITE_VERSION,
          sharded=True)
      self.evaluate(variables.global_variables_initializer())
      t0.insert("k1", 30.0).run()
      t1.insert("k2", 40.0).run()
      val = save.save(sess, save_path)
      # V1 sharded saves return a glob-style pattern over the shard files;
      # V2 returns the plain checkpoint prefix.
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(save_path + "-?????-of-00002", val)
      else:
        self.assertEqual(save_path, val)
      meta_graph_filename = checkpoint_management.meta_graph_filename(val)
      self.assertEqual(save_path + ".meta", meta_graph_filename)

    # Only the V1 format produces individually addressable shard files.
    if save._write_version is saver_pb2.SaverDef.V1:
      # Restore different ops from shard 0 of the saved files.
      with session.Session(
          target="",
          config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
        with sess.graph.device("/cpu:0"):
          v0 = variables.VariableV1(111, name="v0")
          t0 = saver_test_utils.CheckpointedOp(name="t0")
        save = saver_module.Saver(
            {
                "v0": v0,
                "t0": t0.saveable
            },
            write_version=self._WRITE_VERSION,
            sharded=True)
        self.evaluate(variables.global_variables_initializer())
        t0.insert("k11", 33.0).run()
        # Pre-restore state differs from the checkpointed values.
        self.assertEqual(111, self.evaluate(v0))
        self.assertEqual(b"k11", self.evaluate(t0.keys()))
        self.assertEqual(33.0, self.evaluate(t0.values()))
        save.restore(sess, save_path + "-00000-of-00002")
        # Restored state matches what was saved on /cpu:0.
        self.assertEqual(10, self.evaluate(v0))
        self.assertEqual(b"k1", self.evaluate(t0.keys()))
        self.assertEqual(30.0, self.evaluate(t0.values()))

      # Restore different ops from shard 1 of the saved files.
      with session.Session(
          target="",
          config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
        with sess.graph.device("/cpu:0"):
          v1 = variables.VariableV1(222)
          t1 = saver_test_utils.CheckpointedOp(name="t1")
        save = saver_module.Saver(
            {
                "v1": v1,
                "t1": t1.saveable
            },
            write_version=self._WRITE_VERSION,
            sharded=True)
        self.evaluate(variables.global_variables_initializer())
        t1.insert("k22", 44.0).run()
        self.assertEqual(222, self.evaluate(v1))
        self.assertEqual(b"k22", self.evaluate(t1.keys()))
        self.assertEqual(44.0, self.evaluate(t1.values()))
        save.restore(sess, save_path + "-00001-of-00002")
        # Restored state matches what was saved on /cpu:1.
        self.assertEqual(20, self.evaluate(v1))
        self.assertEqual(b"k2", self.evaluate(t1.keys()))
        self.assertEqual(40.0, self.evaluate(t1.values()))

    # Now try a restore with the sharded filename.
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        v0 = variables.VariableV1(111, name="v0")
        t0 = saver_test_utils.CheckpointedOp(name="t0")
      with sess.graph.device("/cpu:1"):
        v1 = variables.VariableV1(222, name="v1")
        t1 = saver_test_utils.CheckpointedOp(name="t1")
      save = saver_module.Saver(
          {
              "v0": v0,
              "v1": v1,
              "t0": t0.saveable,
              "t1": t1.saveable
          },
          write_version=self._WRITE_VERSION,
          sharded=True)
      self.evaluate(variables.global_variables_initializer())
      t0.insert("k11", 33.0).run()
      t1.insert("k22", 44.0).run()
      self.assertEqual(111, self.evaluate(v0))
      self.assertEqual(222, self.evaluate(v1))
      self.assertEqual(b"k11", self.evaluate(t0.keys()))
      self.assertEqual(33.0, self.evaluate(t0.values()))
      self.assertEqual(b"k22", self.evaluate(t1.keys()))
      self.assertEqual(44.0, self.evaluate(t1.values()))
      save_path = os.path.join(self.get_temp_dir(), "sharded_basics")
      if save._write_version is saver_pb2.SaverDef.V1:
        save.restore(sess, save_path + "-?????-of-?????")
      else:
        save.restore(sess, save_path)
      # All values come back from the full sharded checkpoint.
      self.assertEqual(10, self.evaluate(v0))
      self.assertEqual(20, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(t0.keys()))
      self.assertEqual(30.0, self.evaluate(t0.values()))
      self.assertEqual(b"k2", self.evaluate(t1.keys()))
      self.assertEqual(40.0, self.evaluate(t1.values()))

    # latest_checkpoint reports the format-appropriate path.
    if save._write_version is saver_pb2.SaverDef.V1:
      self.assertEqual(
          checkpoint_management.latest_checkpoint(self.get_temp_dir()),
          os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002"))
    else:
      self.assertEqual(
          checkpoint_management.latest_checkpoint(self.get_temp_dir()),
          os.path.join(self.get_temp_dir(), "sharded_basics"))

  def testSaverDef(self):
    """A sharded Saver's SaverDef proto has sharded=True."""
    # train.Saver is V1 only API.
    with ops_lib.Graph().as_default(), self.cached_session():
      v0 = variables.VariableV1(123, name="v0")
      save = saver_module.Saver({"v0": v0}, sharded=True)
      sd = save.as_saver_def()
      self.assertTrue(sd.sharded)

  def _testPartitionedVariables(self, use_resource):
    """Round-trips values between partitioned and unpartitioned variables.

    Args:
      use_resource: whether to build resource variables instead of ref
        variables when saving.
    """
    var_full_shape = [10, 3]
    # Allows save/restore mechanism to work w/ different slicings.
    var_name = "my_var"
    saved_dir = self._get_test_dir("partitioned_variables")
    saved_path = os.path.join(saved_dir, "ckpt")

    call_saver_with_dict = False  # updated by test loop below

    def _save(partitioner=None):
      """Saves the variable (partitioned iff a partitioner is given)."""
      # train.Saver is V1 only API.
      with ops_lib.Graph().as_default(), self.session() as sess:
        # Calls .eval() to return the ndarray that makes up the full variable.
        rnd = random_ops.random_uniform(var_full_shape).eval()

        if partitioner:
          vs = [
              variable_scope.get_variable(
                  var_name,
                  shape=var_full_shape,
                  initializer=rnd,
                  partitioner=partitioner,
                  use_resource=use_resource)
          ]
        else:
          if use_resource:
            vs = [resource_variable_ops.ResourceVariable(rnd, name=var_name)]
          else:
            vs = [variables.VariableV1(rnd, name=var_name)]

        self.evaluate(variables.global_variables_initializer())
        if call_saver_with_dict:
          saver = saver_module.Saver({var_name: vs[0]})
        else:
          saver = saver_module.Saver(vs)
        actual_path = saver.save(sess, saved_path)
        self.assertEqual(saved_path, actual_path)

        return rnd

    def _restore(partitioner=None):
      """Restores into a (possibly differently partitioned) variable."""
      # train.Saver is V1 only API.
      with ops_lib.Graph().as_default(), self.session() as sess:
        if partitioner:
          new_vs = [
              variable_scope.get_variable(
                  var_name,
                  shape=var_full_shape,
                  initializer=array_ops.zeros(var_full_shape),
                  partitioner=partitioner)
          ]
        else:
          new_vs = [
              variables.VariableV1(
                  array_ops.zeros(
                      shape=var_full_shape),  # != original contents.
                  name=var_name)
          ]

        self.evaluate(variables.global_variables_initializer())
        if call_saver_with_dict:
          saver = saver_module.Saver({
              var_name: new_vs[0]
          })
        else:
          saver = saver_module.Saver(new_vs)
        saver.restore(sess, saved_path)

        if partitioner:
          return new_vs[0].as_tensor().eval()
        else:
          return new_vs[0].eval()

    # NOTE(review): iterating a set literal; CPython yields False then True
    # here, but set iteration order is not a documented guarantee.
    for call_saver_with_dict in {False, True}:
      # Save PartitionedVariable and restore into full variable.
      saved_full = _save(
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=2))
      restored_full = _restore()
      self.assertAllEqual(saved_full, restored_full)

      # Restores into the same number of partitions.
      restored_full = _restore(
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=2))
      self.assertAllEqual(saved_full, restored_full)

      # Restores into a different number of partitions.
      restored_full = _restore(
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=3))
      self.assertAllEqual(saved_full, restored_full)

      # Now, saves a full variable and restores PartitionedVariable.
      saved_full = _save()
      restored_full = _restore(
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=3))
      self.assertAllEqual(saved_full, restored_full)

  def testPartitionedVariable(self):
    """Partitioned round-trip with ref variables."""
    self._testPartitionedVariables(use_resource=False)

  def testPartitionedResourceVariable(self):
    """Partitioned round-trip with resource variables."""
    self._testPartitionedVariables(use_resource=True)
1105
1106
class SaveRestoreShardedTestV2(SaveRestoreShardedTest):
  """Re-runs the sharded tests with the V2 format, plus iterator cases."""
  _WRITE_VERSION = saver_pb2.SaverDef.V2

  def testIterators(self):
    """Saves and restores dataset-iterator state across two shards."""
    save_path = os.path.join(self.get_temp_dir(), "sharded_iterators")

    # Build a graph with 2 parameter nodes on different devices and save.
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        ds0 = dataset_ops.Dataset.range(10)
        it0 = dataset_ops.make_initializable_iterator(ds0)
        get_next0 = it0.get_next()
      saveable0 = iterator_ops._IteratorSaveable(
          it0._iterator_resource, name="saveable_it0")

      with sess.graph.device("/cpu:1"):
        ds1 = dataset_ops.Dataset.range(20)
        it1 = dataset_ops.make_initializable_iterator(ds1)
        get_next1 = it1.get_next()
      saveable1 = iterator_ops._IteratorSaveable(
          it1._iterator_resource, name="saveable_it1")
      saver = saver_module.Saver({
          "it0": saveable0,
          "it1": saveable1
      },
                                 write_version=self._WRITE_VERSION,
                                 sharded=True)
      self.evaluate(it0.initializer)
      self.evaluate(it1.initializer)
      # Consume two elements from it0 and one from it1 before saving.
      self.assertEqual(0, self.evaluate(get_next0))
      self.assertEqual(1, self.evaluate(get_next0))
      self.assertEqual(0, self.evaluate(get_next1))
      val = saver.save(sess, save_path)
      self.assertEqual(save_path, val)
      # Sharded V2 save produces one data file per device.
      data_files = glob.glob(save_path + ".data*")
      self.assertEqual(2, len(data_files))

    # Restore
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        ds0 = dataset_ops.Dataset.range(10)
        it0 = dataset_ops.make_initializable_iterator(ds0)
        get_next0 = it0.get_next()
      saveable0 = iterator_ops._IteratorSaveable(
          it0._iterator_resource, name="saveable_it0")

      with sess.graph.device("/cpu:1"):
        ds1 = dataset_ops.Dataset.range(20)
        it1 = dataset_ops.make_initializable_iterator(ds1)
        get_next1 = it1.get_next()
      saveable1 = iterator_ops._IteratorSaveable(
          it1._iterator_resource, name="saveable_it1")
      saver = saver_module.Saver({
          "it0": saveable0,
          "it1": saveable1
      },
                                 write_version=self._WRITE_VERSION,
                                 sharded=True)
      self.evaluate(it0.initializer)
      self.evaluate(it1.initializer)
      saver.restore(sess, save_path)
      # Iterators resume where they were at save time.
      self.assertEqual(2, self.evaluate(get_next0))
      self.assertEqual(1, self.evaluate(get_next1))

  def testIteratorsUnshardedRestore(self):
    """A sharded save can be restored by an unsharded saver."""
    save_path = os.path.join(self.get_temp_dir(), "restore_unsharded_iterators")

    # Build a graph with 2 parameter nodes on different devices and save.
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        ds0 = dataset_ops.Dataset.range(10)
        it0 = dataset_ops.make_initializable_iterator(ds0)
        get_next0 = it0.get_next()
      saveable0 = iterator_ops._IteratorSaveable(
          it0._iterator_resource, name="saveable_it0")

      with sess.graph.device("/cpu:1"):
        ds1 = dataset_ops.Dataset.range(20)
        it1 = dataset_ops.make_initializable_iterator(ds1)
        get_next1 = it1.get_next()
      saveable1 = iterator_ops._IteratorSaveable(
          it1._iterator_resource, name="saveable_it1")
      saver = saver_module.Saver({
          "it0": saveable0,
          "it1": saveable1
      },
                                 write_version=self._WRITE_VERSION,
                                 sharded=True)
      self.evaluate(it0.initializer)
      self.evaluate(it1.initializer)
      # Consume two elements from it0 and one from it1 before saving.
      self.assertEqual(0, self.evaluate(get_next0))
      self.assertEqual(1, self.evaluate(get_next0))
      self.assertEqual(0, self.evaluate(get_next1))
      val = saver.save(sess, save_path)
      self.assertEqual(save_path, val)
      data_files = glob.glob(save_path + ".data*")
      self.assertEqual(2, len(data_files))

    # Restore
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        ds0 = dataset_ops.Dataset.range(10)
        it0 = dataset_ops.make_initializable_iterator(ds0)
        get_next0 = it0.get_next()
      saveable0 = iterator_ops._IteratorSaveable(
          it0._iterator_resource, name="saveable_it0")

      with sess.graph.device("/cpu:1"):
        ds1 = dataset_ops.Dataset.range(20)
        it1 = dataset_ops.make_initializable_iterator(ds1)
        get_next1 = it1.get_next()
      saveable1 = iterator_ops._IteratorSaveable(
          it1._iterator_resource, name="saveable_it1")
      # Restore with sharded=False even though the save was sharded.
      saver = saver_module.Saver({
          "it0": saveable0,
          "it1": saveable1
      },
                                 write_version=self._WRITE_VERSION,
                                 sharded=False)
      self.evaluate(it0.initializer)
      self.evaluate(it1.initializer)
      saver.restore(sess, save_path)
      self.assertEqual(2, self.evaluate(get_next0))
      self.assertEqual(1, self.evaluate(get_next1))
1239
1240
1241class MaxToKeepTest(test.TestCase):
1242
1243  def _get_test_dir(self, dirname):
1244    test_dir = os.path.join(self.get_temp_dir(), dirname)
1245    gfile.MakeDirs(test_dir)
1246    return test_dir
1247
1248  def assertCheckpointState(self, model_checkpoint_path,
1249                            all_model_checkpoint_paths, save_dir):
1250    checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir)
1251    self.assertEqual(checkpoint_state.model_checkpoint_path,
1252                     model_checkpoint_path)
1253    self.assertEqual(checkpoint_state.all_model_checkpoint_paths,
1254                     all_model_checkpoint_paths)
1255
  def testMaxToKeepEager(self):
    """Exercises max_to_keep retention under eager execution.

    With max_to_keep=2 the saver keeps only the two most recent
    checkpoints, deleting the oldest as new ones are written. Also checks
    that a second saver with copied bookkeeping deletes according to its
    own list.
    """
    with context.eager_mode():
      save_dir = self._get_test_dir("max_to_keep_eager")

      v = variable_scope.variable(10.0, name="v")
      save = saver_module.Saver({"v": v}, max_to_keep=2)
      self.evaluate(variables.global_variables_initializer())
      if not context.executing_eagerly():
        self.assertEqual([], save.last_checkpoints)

      s1 = save.save(None, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], save.last_checkpoints)
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s1],
          save_dir=save_dir)

      s2 = save.save(None, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s1, s2],
          save_dir=save_dir)

      # Third save exceeds max_to_keep=2, so s1 is deleted.
      s3 = save.save(None, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertCheckpointState(
          model_checkpoint_path=s3,
          all_model_checkpoint_paths=[s2, s3],
          save_dir=save_dir)

      # Create a second helper, identical to the first.
      save2 = saver_module.Saver({"v": v}, max_to_keep=2)
      save2.set_last_checkpoints(save.last_checkpoints)

      # Exercise the first helper.

      # Adding s2 again (old s2 is removed first, then new s2 appended)
      s2 = save.save(None, os.path.join(save_dir, "s2"))
      self.assertEqual([s3, s2], save.last_checkpoints)
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s3, s2],
          save_dir=save_dir)

      # Adding s1 (s3 should now be deleted as oldest in list)
      s1 = save.save(None, os.path.join(save_dir, "s1"))
      self.assertEqual([s2, s1], save.last_checkpoints)
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s2, s1],
          save_dir=save_dir)

      # Exercise the second helper, whose bookkeeping still lists [s3, s2].
      s2 = save2.save(None, os.path.join(save_dir, "s2"))
      self.assertEqual([s3, s2], save2.last_checkpoints)
      # Created by the first helper.
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      # Deleted by the first helper.
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
1326
1327  def testNonSharded(self):
1328    save_dir = self._get_test_dir("max_to_keep_non_sharded")
1329
1330    # train.Saver is V1 only API.
1331    with ops_lib.Graph().as_default(), self.cached_session() as sess:
1332      v = variables.VariableV1(10.0, name="v")
1333      save = saver_module.Saver({"v": v}, max_to_keep=2)
1334      self.evaluate(variables.global_variables_initializer())
1335      self.assertEqual([], save.last_checkpoints)
1336
1337      s1 = save.save(sess, os.path.join(save_dir, "s1"))
1338      self.assertEqual([s1], save.last_checkpoints)
1339      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
1340      self.assertCheckpointState(
1341          model_checkpoint_path=s1,
1342          all_model_checkpoint_paths=[s1],
1343          save_dir=save_dir)
1344
1345      s2 = save.save(sess, os.path.join(save_dir, "s2"))
1346      self.assertEqual([s1, s2], save.last_checkpoints)
1347      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
1348      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
1349      self.assertCheckpointState(
1350          model_checkpoint_path=s2,
1351          all_model_checkpoint_paths=[s1, s2],
1352          save_dir=save_dir)
1353
1354      s3 = save.save(sess, os.path.join(save_dir, "s3"))
1355      self.assertEqual([s2, s3], save.last_checkpoints)
1356      self.assertFalse(checkpoint_management.checkpoint_exists(s1))
1357      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
1358      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
1359      self.assertCheckpointState(
1360          model_checkpoint_path=s3,
1361          all_model_checkpoint_paths=[s2, s3],
1362          save_dir=save_dir)
1363
1364      # Create a second helper, identical to the first.
1365      save2 = saver_module.Saver(saver_def=save.as_saver_def())
1366      save2.set_last_checkpoints(save.last_checkpoints)
1367
1368      # Create a third helper, with the same configuration but no knowledge of
1369      # previous checkpoints.
1370      save3 = saver_module.Saver(saver_def=save.as_saver_def())
1371
1372      # Exercise the first helper.
1373
1374      # Adding s2 again (old s2 is removed first, then new s2 appended)
1375      s2 = save.save(sess, os.path.join(save_dir, "s2"))
1376      self.assertEqual([s3, s2], save.last_checkpoints)
1377      self.assertFalse(checkpoint_management.checkpoint_exists(s1))
1378      self.assertFalse(
1379          checkpoint_management.checkpoint_exists(
1380              checkpoint_management.meta_graph_filename(s1)))
1381      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
1382      self.assertTrue(
1383          checkpoint_management.checkpoint_exists(
1384              checkpoint_management.meta_graph_filename(s3)))
1385      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
1386      self.assertTrue(
1387          checkpoint_management.checkpoint_exists(
1388              checkpoint_management.meta_graph_filename(s2)))
1389      self.assertCheckpointState(
1390          model_checkpoint_path=s2,
1391          all_model_checkpoint_paths=[s3, s2],
1392          save_dir=save_dir)
1393
1394      # Adding s1 (s3 should now be deleted as oldest in list)
1395      s1 = save.save(sess, os.path.join(save_dir, "s1"))
1396      self.assertEqual([s2, s1], save.last_checkpoints)
1397      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
1398      self.assertFalse(
1399          checkpoint_management.checkpoint_exists(
1400              checkpoint_management.meta_graph_filename(s3)))
1401      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
1402      self.assertTrue(
1403          checkpoint_management.checkpoint_exists(
1404              checkpoint_management.meta_graph_filename(s2)))
1405      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
1406      self.assertTrue(
1407          checkpoint_management.checkpoint_exists(
1408              checkpoint_management.meta_graph_filename(s1)))
1409      self.assertCheckpointState(
1410          model_checkpoint_path=s1,
1411          all_model_checkpoint_paths=[s2, s1],
1412          save_dir=save_dir)
1413
1414      # Exercise the second helper.
1415
1416      # Adding s2 again (old s2 is removed first, then new s2 appended)
1417      s2 = save2.save(sess, os.path.join(save_dir, "s2"))
1418      self.assertEqual([s3, s2], save2.last_checkpoints)
1419      # Created by the first helper.
1420      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
1421      self.assertTrue(
1422          checkpoint_management.checkpoint_exists(
1423              checkpoint_management.meta_graph_filename(s1)))
1424      # Deleted by the first helper.
1425      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
1426      self.assertFalse(
1427          checkpoint_management.checkpoint_exists(
1428              checkpoint_management.meta_graph_filename(s3)))
1429      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
1430      self.assertTrue(
1431          checkpoint_management.checkpoint_exists(
1432              checkpoint_management.meta_graph_filename(s2)))
1433      self.assertCheckpointState(
1434          model_checkpoint_path=s2,
1435          all_model_checkpoint_paths=[s3, s2],
1436          save_dir=save_dir)
1437
1438      # Adding s1 (s3 should now be deleted as oldest in list)
1439      s1 = save2.save(sess, os.path.join(save_dir, "s1"))
1440      self.assertEqual([s2, s1], save2.last_checkpoints)
1441      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
1442      self.assertFalse(
1443          checkpoint_management.checkpoint_exists(
1444              checkpoint_management.meta_graph_filename(s3)))
1445      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
1446      self.assertTrue(
1447          checkpoint_management.checkpoint_exists(
1448              checkpoint_management.meta_graph_filename(s2)))
1449      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
1450      self.assertTrue(
1451          checkpoint_management.checkpoint_exists(
1452              checkpoint_management.meta_graph_filename(s1)))
1453      self.assertCheckpointState(
1454          model_checkpoint_path=s1,
1455          all_model_checkpoint_paths=[s2, s1],
1456          save_dir=save_dir)
1457
1458      # Exercise the third helper.
1459
1460      # Adding s2 again (but helper is unaware of previous s2)
1461      s2 = save3.save(sess, os.path.join(save_dir, "s2"))
1462      self.assertEqual([s2], save3.last_checkpoints)
1463      # Created by the first helper.
1464      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
1465      self.assertTrue(
1466          checkpoint_management.checkpoint_exists(
1467              checkpoint_management.meta_graph_filename(s1)))
1468      # Deleted by the first helper.
1469      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
1470      self.assertFalse(
1471          checkpoint_management.checkpoint_exists(
1472              checkpoint_management.meta_graph_filename(s3)))
1473      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
1474      self.assertTrue(
1475          checkpoint_management.checkpoint_exists(
1476              checkpoint_management.meta_graph_filename(s2)))
1477      # Even though the file for s1 exists, this saver isn't aware of it, which
1478      # is why it doesn't end up in the checkpoint state.
1479      self.assertCheckpointState(
1480          model_checkpoint_path=s2,
1481          all_model_checkpoint_paths=[s2],
1482          save_dir=save_dir)
1483
1484      # Adding s1 (s3 should not be deleted because helper is unaware of it)
1485      s1 = save3.save(sess, os.path.join(save_dir, "s1"))
1486      self.assertEqual([s2, s1], save3.last_checkpoints)
1487      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
1488      self.assertFalse(
1489          checkpoint_management.checkpoint_exists(
1490              checkpoint_management.meta_graph_filename(s3)))
1491      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
1492      self.assertTrue(
1493          checkpoint_management.checkpoint_exists(
1494              checkpoint_management.meta_graph_filename(s2)))
1495      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
1496      self.assertTrue(
1497          checkpoint_management.checkpoint_exists(
1498              checkpoint_management.meta_graph_filename(s1)))
1499      self.assertCheckpointState(
1500          model_checkpoint_path=s1,
1501          all_model_checkpoint_paths=[s2, s1],
1502          save_dir=save_dir)
1503
1504  def testSharded(self):
1505    save_dir = self._get_test_dir("max_to_keep_sharded")
1506
1507    with session.Session(
1508        target="",
1509        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
1510      with sess.graph.device("/cpu:0"):
1511        v0 = variables.VariableV1(111, name="v0")
1512      with sess.graph.device("/cpu:1"):
1513        v1 = variables.VariableV1(222, name="v1")
1514      save = saver_module.Saver(
1515          {
1516              "v0": v0,
1517              "v1": v1
1518          }, sharded=True, max_to_keep=2)
1519      self.evaluate(variables.global_variables_initializer())
1520      self.assertEqual([], save.last_checkpoints)
1521
1522      s1 = save.save(sess, os.path.join(save_dir, "s1"))
1523      self.assertEqual([s1], save.last_checkpoints)
1524      if save._write_version is saver_pb2.SaverDef.V1:
1525        self.assertEqual(2, len(gfile.Glob(s1)))
1526      else:
1527        self.assertEqual(4, len(gfile.Glob(s1 + "*")))
1528
1529      self.assertTrue(
1530          gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
1531
1532      s2 = save.save(sess, os.path.join(save_dir, "s2"))
1533      self.assertEqual([s1, s2], save.last_checkpoints)
1534      if save._write_version is saver_pb2.SaverDef.V1:
1535        self.assertEqual(2, len(gfile.Glob(s1)))
1536      else:
1537        self.assertEqual(4, len(gfile.Glob(s1 + "*")))
1538      self.assertTrue(
1539          gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
1540      if save._write_version is saver_pb2.SaverDef.V1:
1541        self.assertEqual(2, len(gfile.Glob(s2)))
1542      else:
1543        self.assertEqual(4, len(gfile.Glob(s2 + "*")))
1544      self.assertTrue(
1545          gfile.Exists(checkpoint_management.meta_graph_filename(s2)))
1546
1547      s3 = save.save(sess, os.path.join(save_dir, "s3"))
1548      self.assertEqual([s2, s3], save.last_checkpoints)
1549      self.assertEqual(0, len(gfile.Glob(s1 + "*")))
1550      self.assertFalse(
1551          gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
1552      if save._write_version is saver_pb2.SaverDef.V1:
1553        self.assertEqual(2, len(gfile.Glob(s2)))
1554      else:
1555        self.assertEqual(4, len(gfile.Glob(s2 + "*")))
1556      self.assertTrue(
1557          gfile.Exists(checkpoint_management.meta_graph_filename(s2)))
1558      if save._write_version is saver_pb2.SaverDef.V1:
1559        self.assertEqual(2, len(gfile.Glob(s3)))
1560      else:
1561        self.assertEqual(4, len(gfile.Glob(s3 + "*")))
1562      self.assertTrue(
1563          gfile.Exists(checkpoint_management.meta_graph_filename(s3)))
1564
1565  def testNoMaxToKeep(self):
1566    save_dir = self._get_test_dir("no_max_to_keep")
1567    save_dir2 = self._get_test_dir("max_to_keep_0")
1568
1569    with self.cached_session() as sess:
1570      v = variables.VariableV1(10.0, name="v")
1571      self.evaluate(variables.global_variables_initializer())
1572
1573      # Test max_to_keep being None.
1574      save = saver_module.Saver({"v": v}, max_to_keep=None)
1575      self.assertEqual([], save.last_checkpoints)
1576      s1 = save.save(sess, os.path.join(save_dir, "s1"))
1577      self.assertEqual([], save.last_checkpoints)
1578      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
1579      s2 = save.save(sess, os.path.join(save_dir, "s2"))
1580      self.assertEqual([], save.last_checkpoints)
1581      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
1582
1583      # Test max_to_keep being 0.
1584      save2 = saver_module.Saver({"v": v}, max_to_keep=0)
1585      self.assertEqual([], save2.last_checkpoints)
1586      s1 = save2.save(sess, os.path.join(save_dir2, "s1"))
1587      self.assertEqual([], save2.last_checkpoints)
1588      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
1589      s2 = save2.save(sess, os.path.join(save_dir2, "s2"))
1590      self.assertEqual([], save2.last_checkpoints)
1591      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
1592
1593  def testNoMetaGraph(self):
1594    save_dir = self._get_test_dir("no_meta_graph")
1595
1596    with self.cached_session() as sess:
1597      v = variables.VariableV1(10.0, name="v")
1598      save = saver_module.Saver({"v": v})
1599      self.evaluate(variables.global_variables_initializer())
1600
1601      s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False)
1602      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
1603      self.assertFalse(
1604          gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
1605
1606
class RecoverLastCheckpointsTest(test.TestCase):
  """Tests for Saver.recover_last_checkpoints."""

  def _get_test_dir(self, dirname):
    """Creates (if needed) and returns a scratch directory for this test."""
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def assertCheckpointState(self, model_checkpoint_path,
                            all_model_checkpoint_paths, save_dir):
    """Asserts the checkpoint state proto in save_dir matches expectations."""
    state = checkpoint_management.get_checkpoint_state(save_dir)
    self.assertEqual(state.model_checkpoint_path, model_checkpoint_path)
    self.assertEqual(state.all_model_checkpoint_paths,
                     all_model_checkpoint_paths)

  def test_recover_last_checkpoints(self):
    with context.eager_mode():
      save_dir = self._get_test_dir("recover_last_checkpoints")

      v = variable_scope.variable(10.0, name="v")
      saver = saver_module.Saver({"v": v}, max_to_keep=10)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual([], saver.last_checkpoints)

      paths = [
          saver.save(None, os.path.join(save_dir, "ckpt-%d" % i))
          for i in (1, 2, 3)
      ]
      s1, s2, s3 = paths
      self.assertEqual(paths, saver.last_checkpoints)
      for path in paths:
        self.assertTrue(checkpoint_management.checkpoint_exists(path))
      self.assertCheckpointState(
          model_checkpoint_path=s3,
          all_model_checkpoint_paths=paths,
          save_dir=save_dir)

      # A fresh saver has no history; recover_last_checkpoints repopulates
      # it from the provided paths.
      recovering = saver_module.Saver({"v": v}, max_to_keep=10)
      self.assertEqual([], recovering.last_checkpoints)
      recovering.recover_last_checkpoints(paths)
      self.assertEqual(paths, recovering.last_checkpoints)

      # Remove s1 from disk; recovery must skip checkpoints that no
      # longer exist.
      for fname in gfile.Glob("{}*".format(s1)):
        gfile.Remove(fname)
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))

      recovering2 = saver_module.Saver({"v": v}, max_to_keep=10)
      self.assertEqual([], recovering2.last_checkpoints)
      recovering2.recover_last_checkpoints(paths)
      self.assertEqual([s2, s3], recovering2.last_checkpoints)
      s4 = recovering2.save(None, os.path.join(save_dir, "ckpt-4"))
      self.assertCheckpointState(
          model_checkpoint_path=s4,
          all_model_checkpoint_paths=[s2, s3, s4],
          save_dir=save_dir)
1666
1667
class KeepCheckpointEveryNHoursTest(test.TestCase):
  """Tests the keep_checkpoint_every_n_hours retention policy."""

  def _get_test_dir(self, dirname):
    """Creates (if needed) and returns a scratch directory for this test."""
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(saver_module, "time")
  def testNonSharded(self, mock_time):
    save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")

    with self.cached_session() as sess:
      v = variable_scope.variable([10.0], name="v")
      # Run the initializer NOW to avoid the 0.5s overhead of the first Run()
      # call, which throws the test timing off in fastbuild mode.
      self.evaluate(variables.global_variables_initializer())
      # Keep the last 2 checkpoints, plus one every 0.7 seconds.  The clock
      # seen by the saver is mocked, which makes the policy deterministic.
      start_time = time.time()
      mock_time.time.return_value = start_time
      saver = saver_module.Saver(
          {
              "v": v
          }, max_to_keep=2, keep_checkpoint_every_n_hours=0.7 / 3600)
      self.assertEqual([], saver.last_checkpoints)

      # Advance the mocked clock by a full second so s1 is old enough to be
      # retained by the every-n-hours exemption later on.
      mock_time.time.return_value = start_time + 1.0
      s1 = saver.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], saver.last_checkpoints)

      s2 = saver.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], saver.last_checkpoints)

      # max_to_keep is 2, so saving s3 would normally delete s1.  But s1 is
      # older than 0.7s, so it is kept on disk (only dropped from the list).
      s3 = saver.save(sess, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], saver.last_checkpoints)

      # Not checking s1 on disk here keeps timing variance out of the test.

      # Saving s4 evicts s2.  s2 is too close in time to the already-kept s1
      # to qualify for the exemption, so it is actually deleted.
      s4 = saver.save(sess, os.path.join(save_dir, "s4"))
      self.assertEqual([s3, s4], saver.last_checkpoints)

      # s1 survived the policy, s2 is gone, s3/s4 are the live window.
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertFalse(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(checkpoint_management.checkpoint_exists(s4))
1725
1726
class SaveRestoreWithVariableNameMap(test.TestCase):
  """Tests saving and restoring through a variable-name map."""

  def _testNonReshape(self, variable_op):
    save_path = os.path.join(self.get_temp_dir(), "non_reshape")

    with self.session(graph=ops_lib.Graph()) as sess:
      # Two parameter nodes, saved under "save_prefix/"-prefixed names.
      v0 = variable_op(10.0, name="v0")
      v1 = variable_op(20.0, name="v1")
      saver = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
      self.evaluate(variables.global_variables_initializer())

      # Sanity-check the initialized values.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))

      # Save under the mapped tensor names; save() returns the path.
      val = saver.save(sess, save_path)
      self.assertTrue(isinstance(val, six.string_types))
      self.assertEqual(save_path, val)

      # The original (unmapped) names must not be in the checkpoint.
      saver = saver_module.Saver({"v0": v0, "v1": v1})
      with self.assertRaisesOpError("not found in checkpoint"):
        saver.restore(sess, save_path)

    # The mapped names are present and restorable into fresh variables.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variable_op(-1.0, name="v0")
      v1 = variable_op(-1.0, name="v1")

      if not context.executing_eagerly():
        with self.assertRaisesOpError("uninitialized"):
          self.evaluate(v0)
        with self.assertRaisesOpError("uninitialized"):
          self.evaluate(v1)

      saver = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
      saver.restore(sess, save_path)

      # The saved values must come back through the name map.
      if not context.executing_eagerly():
        self.assertEqual(10.0, self.evaluate(v0))
        self.assertEqual(20.0, self.evaluate(v1))

    # Restoring also works when the graph-side names carry an extra prefix.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variable_op(-1.0, name="restore_prefix/v0")
      v1 = variable_op(-1.0, name="restore_prefix/v1")

      if not context.executing_eagerly():
        with self.assertRaisesOpError("uninitialized"):
          self.evaluate(v0)
        with self.assertRaisesOpError("uninitialized"):
          self.evaluate(v1)

      saver = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
      saver.restore(sess, save_path)

      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))

  @test_util.run_in_graph_and_eager_modes
  def testNonReshapeResourceVariable(self):
    self._testNonReshape(resource_variable_ops.ResourceVariable)

  def testNonReshapeVariable(self):
    self._testNonReshape(variables.Variable)
1801
1802
1803class MetaGraphTest(test.TestCase):
1804
1805  def _get_test_dir(self, dirname):
1806    test_dir = os.path.join(self.get_temp_dir(), dirname)
1807    gfile.MakeDirs(test_dir)
1808    return test_dir
1809
  @test_util.run_v1_only(
      "Queue-based input pipelines have been replaced by `tf.data` "
      "and not supported in V2.")
  def testAddCollectionDef(self):
    """Exports a graph with many collection types and re-imports it.

    Builds a graph containing control flow, a queue runner, and collections
    holding ints, floats, strings, variables, and user-defined protos (as
    string, bytes, and Any).  Exports it to a MetaGraphDef, re-imports it
    into a fresh graph, and checks that re-exporting yields an identical
    MetaGraphDef.
    """
    test_dir = self._get_test_dir("good_collection")
    filename = os.path.join(test_dir, "metafile")
    with self.cached_session():
      # Creates a graph.
      v0 = variables.VariableV1(1.0, name="v0")
      # cond/while_loop are here only to exercise control-flow-context
      # serialization in the meta graph.
      control_flow_ops.cond(
          math_ops.less(v0, 10), lambda: math_ops.add(v0, 1),
          lambda: math_ops.subtract(v0, 1))
      control_flow_ops.while_loop(lambda i: math_ops.less(i, 10),
                                  lambda i: math_ops.add(i, 1), [v0])
      var = variables.VariableV1(constant_op.constant(0, dtype=dtypes.int64))
      count_up_to = var.count_up_to(3)
      input_queue = data_flow_ops.FIFOQueue(
          30, dtypes.float32, shared_name="collection_queue")
      qr = queue_runner_impl.QueueRunner(input_queue, [count_up_to])
      variables.global_variables_initializer()
      # Creates a saver.
      save = saver_module.Saver({"v0": v0})
      # Adds a set of collections.
      ops_lib.add_to_collection("int_collection", 3)
      ops_lib.add_to_collection("float_collection", 3.5)
      ops_lib.add_to_collection("string_collection", "hello")
      ops_lib.add_to_collection("variable_collection", v0)
      # Add QueueRunners.
      queue_runner_impl.add_queue_runner(qr)
      # Adds user_defined proto in three formats: string, bytes and Any.
      queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")
      ops_lib.add_to_collection("user_defined_string_collection",
                                str(queue_runner))
      ops_lib.add_to_collection("user_defined_bytes_collection",
                                queue_runner.SerializeToString())
      any_buf = Any()
      any_buf.Pack(queue_runner)
      ops_lib.add_to_collection("user_defined_any_collection", any_buf)

      # Generates MetaGraphDef.
      meta_graph_def = save.export_meta_graph(filename)
      self.assertTrue(meta_graph_def.HasField("saver_def"))
      self.assertTrue(meta_graph_def.HasField("graph_def"))
      self.assertTrue(meta_graph_def.HasField("meta_info_def"))
      self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_version, "")
      self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_git_version,
                          "")
      collection_def = meta_graph_def.collection_def
      # 12 = the 8 collections added explicitly above plus implicit ones
      # created as a side effect of graph construction — presumably
      # variables/queue-runner/control-flow collections; count is tied to
      # the exact construction sequence in this test.
      self.assertEqual(len(collection_def), 12)

    with ops_lib.Graph().as_default():
      # Restores from MetaGraphDef.
      new_saver = saver_module.import_meta_graph(filename)
      # Generates a new MetaGraphDef.
      new_meta_graph_def = new_saver.export_meta_graph()
      # It should be the same as the original.

    test_util.assert_meta_graph_protos_equal(
        self, meta_graph_def, new_meta_graph_def)
1869
1870  def testAddCollectionDefFails(self):
1871    with self.cached_session():
1872      # Creates a graph.
1873      v0 = variables.VariableV1(10.0, name="v0")
1874      # Creates a saver.
1875      save = saver_module.Saver({"v0": v0})
1876      # Generates MetaGraphDef.
1877      meta_graph_def = meta_graph_pb2.MetaGraphDef()
1878
1879      # Verifies that collection with unsupported key will not be added.
1880      ops_lib.add_to_collection(save, 3)
1881      save._add_collection_def(meta_graph_def, save)
1882      self.assertEqual(len(meta_graph_def.collection_def), 0)
1883
1884      # Verifies that collection where item type does not match expected
1885      # type will not be added.
1886      ops_lib.add_to_collection("int_collection", 3)
1887      ops_lib.add_to_collection("int_collection", 3.5)
1888      save._add_collection_def(meta_graph_def, "int_collection")
1889      self.assertEqual(len(meta_graph_def.collection_def), 0)
1890
  def _testMultiSaverCollectionSave(self, test_dir):
    """Saves two savers into the SAVERS collection and exports meta graphs.

    Writes the two checkpoints and the meta graph file into `test_dir` so
    that _testMultiSaverCollectionRestore can read them back.
    """
    filename = os.path.join(test_dir, "metafile")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
    with self.session(graph=ops_lib.Graph()) as sess:
      # Creates a graph.
      v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
      v1 = variables.VariableV1(11.0, name="v1")
      # Creates 2 savers, each covering one variable only.
      saver0 = saver_module.Saver({"v0": v0}, name="saver0")
      saver1 = saver_module.Saver({"v1": v1}, name="saver1")
      ops_lib.add_to_collection("savers", saver0)
      ops_lib.add_to_collection("savers", saver1)
      self.evaluate(variables.global_variables_initializer())
      # Saves to different checkpoints.
      saver0.save(sess, saver0_ckpt)
      saver1.save(sess, saver1_ckpt)
      # Generates MetaGraphDef: one module-level export (no owning saver)
      # and one export per saver.
      meta_graph_def = saver_module.export_meta_graph(filename)
      meta_graph_def0 = saver0.export_meta_graph()
      meta_graph_def1 = saver1.export_meta_graph()

      # Verifies that there is no saver_def in meta_graph_def.
      self.assertFalse(meta_graph_def.HasField("saver_def"))
      # Verifies that there is saver_def in meta_graph_def0 and 1.
      self.assertTrue(meta_graph_def0.HasField("saver_def"))
      self.assertTrue(meta_graph_def1.HasField("saver_def"))

      # Verifies SAVERS is saved as bytes_list for meta_graph_def.
      collection_def = meta_graph_def.collection_def["savers"]
      kind = collection_def.WhichOneof("kind")
      self.assertEqual(kind, "bytes_list")
      # Verifies that there are 2 entries in SAVERS collection.
      savers = getattr(collection_def, kind)
      self.assertEqual(2, len(savers.value))

      # Verifies SAVERS collection is saved as bytes_list for meta_graph_def0.
      collection_def = meta_graph_def0.collection_def["savers"]
      kind = collection_def.WhichOneof("kind")
      self.assertEqual(kind, "bytes_list")
      # Verifies that there are 2 entries in SAVERS collection.
      savers = getattr(collection_def, kind)
      self.assertEqual(2, len(savers.value))
1934
  def _testMultiSaverCollectionRestore(self, test_dir):
    """Imports the meta graph written by _testMultiSaverCollectionSave.

    Checks that both savers come back through the SAVERS collection and
    that each restores only its own variable.
    """
    filename = os.path.join(test_dir, "metafile")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
    with self.session(graph=ops_lib.Graph()) as sess:
      # Imports from meta_graph.
      saver_module.import_meta_graph(filename)
      # Retrieves SAVERS collection. Verifies there are 2 entries.
      savers = ops_lib.get_collection("savers")
      self.assertEqual(2, len(savers))
      # Retrieves saver0. Verifies that new_saver0 can restore v0, but not v1.
      new_saver0 = savers[0]
      new_saver0.restore(sess, saver0_ckpt)
      v0 = sess.graph.get_tensor_by_name("v0:0")
      v1 = sess.graph.get_tensor_by_name("v1:0")
      self.assertAllEqual([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                          self.evaluate(v0))
      self.assertEqual([3, 2], v0.get_shape())
      self.assertEqual([], v1.get_shape())
      # v1 is not covered by saver0, so it must still be uninitialized.
      with self.assertRaisesWithPredicateMatch(
          errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):
        self.evaluate(v1)
      # Retrieves saver1. Verifies that new_saver1 can restore v1.
      new_saver1 = savers[1]
      new_saver1.restore(sess, saver1_ckpt)
      v1 = sess.graph.get_tensor_by_name("v1:0")
      self.assertEqual(11.0, self.evaluate(v1))
1962
1963  @test_util.run_v1_only(
1964      "Exporting/importing meta graphs is only supported in V1.")
1965  def testMultiSaverCollection(self):
1966    test_dir = self._get_test_dir("saver_collection")
1967    self._testMultiSaverCollectionSave(test_dir)
1968    self._testMultiSaverCollectionRestore(test_dir)
1969
  @test_util.run_v1_only(
      "Exporting/importing meta graphs is only supported in V1.")
  def testClearExtraneousSavers(self):
    """export_meta_graph(clear_extraneous_savers=True) strips other savers.

    With the flag set, the exported meta graph keeps only the exporting
    saver in the SAVERS collection and drops the other saver's graph nodes.
    """
    test_dir = self._get_test_dir("clear_extraneous_savers")
    filename = os.path.join(test_dir, "metafile")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
    with self.session(graph=ops_lib.Graph()) as sess:
      # Creates a graph.
      v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
      v1 = variables.VariableV1(11.0, name="v1")

      # Creates 2 savers.
      saver0 = saver_module.Saver({"v0": v0}, name="saver0")
      saver1 = saver_module.Saver({"v1": v1}, name="saver1")
      ops_lib.add_to_collection("savers", saver0)
      ops_lib.add_to_collection("savers", saver1)
      self.evaluate(variables.global_variables_initializer())

      # Saves to different checkpoints.
      saver0.save(sess, saver0_ckpt)
      saver1.save(sess, saver1_ckpt)

      # Generates MetaGraphDef.  Only meta_graph_def1 clears extraneous
      # savers; meta_graph_def0 serves as the unfiltered baseline.
      meta_graph_def = saver_module.export_meta_graph(filename)
      meta_graph_def0 = saver0.export_meta_graph()
      meta_graph_def1 = saver1.export_meta_graph(clear_extraneous_savers=True)

      # Verifies that there is no saver_def in meta_graph_def.
      self.assertFalse(meta_graph_def.HasField("saver_def"))
      # Verifies that there is saver_def in meta_graph_def0 and 1.
      self.assertTrue(meta_graph_def0.HasField("saver_def"))
      self.assertTrue(meta_graph_def1.HasField("saver_def"))

      # Verifies SAVERS is saved as bytes_list for meta_graph_def.
      collection_def = meta_graph_def.collection_def["savers"]
      kind = collection_def.WhichOneof("kind")
      self.assertEqual(kind, "bytes_list")

      # Verifies that there are 2 entries in SAVERS collection.
      savers = getattr(collection_def, kind)
      self.assertEqual(2, len(savers.value))

      # Verifies SAVERS collection is saved as bytes_list for meta_graph_def1.
      collection_def = meta_graph_def1.collection_def["savers"]
      kind = collection_def.WhichOneof("kind")
      self.assertEqual(kind, "bytes_list")

      # Verifies that there is 1 entry in SAVERS collection.
      savers = getattr(collection_def, kind)
      self.assertEqual(1, len(savers.value))

      # Verifies that saver0 graph nodes are omitted from the saver1 export
      # NOTE(review): the 33/21 node counts are tied to the exact graph
      # construction above and to the saver's generated node layout.
      self.assertEqual(33, len(meta_graph_def0.graph_def.node))
      self.assertEqual(21, len(meta_graph_def1.graph_def.node))
2025
2026  def testBinaryAndTextFormat(self):
2027    test_dir = self._get_test_dir("binary_and_text")
2028    filename = os.path.join(test_dir, "metafile")
2029    # train.Saver is V1 only API.
2030    with ops_lib.Graph().as_default(), self.session():
2031      # Creates a graph.
2032      variables.VariableV1(10.0, name="v0")
2033      # Exports the graph as binary format.
2034      saver_module.export_meta_graph(filename, as_text=False)
2035    with ops_lib.Graph().as_default(), self.session():
2036      # Imports the binary format graph.
2037      saver = saver_module.import_meta_graph(filename)
2038      self.assertIsNotNone(saver)
2039      # Exports the graph as text format.
2040      saver.export_meta_graph(filename, as_text=True)
2041    with ops_lib.Graph().as_default(), self.session():
2042      # Imports the text format graph.
2043      saver_module.import_meta_graph(filename)
2044      # Writes wrong contents to the file.
2045      graph_io.write_graph(saver.as_saver_def(),
2046                           os.path.dirname(filename),
2047                           os.path.basename(filename))
2048    with ops_lib.Graph().as_default(), self.session():
2049      # Import should fail.
2050      with self.assertRaisesWithPredicateMatch(IOError,
2051                                               lambda e: "Cannot parse file"):
2052        saver_module.import_meta_graph(filename)
2053      # Deletes the file
2054      gfile.Remove(filename)
2055      with self.assertRaisesWithPredicateMatch(IOError,
2056                                               lambda e: "does not exist"):
2057        saver_module.import_meta_graph(filename)
2058
2059  @test_util.run_v1_only(
2060      "Exporting/importing meta graphs is only supported in V1.")
2061  def testSliceVariable(self):
2062    test_dir = self._get_test_dir("slice_saver")
2063    filename = os.path.join(test_dir, "metafile")
2064    with self.cached_session():
2065      v1 = variables.VariableV1([20.0], name="v1")
2066      v2 = variables.VariableV1([20.0], name="v2")
2067      v2._set_save_slice_info(
2068          variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))
2069
2070      # The names are different and will work.
2071      slice_saver = saver_module.Saver({"first": v1, "second": v2})
2072      self.evaluate(variables.global_variables_initializer())
2073      # Exports to meta_graph
2074      meta_graph_def = slice_saver.export_meta_graph(filename)
2075
2076    with ops_lib.Graph().as_default():
2077      # Restores from MetaGraphDef.
2078      new_saver = saver_module.import_meta_graph(filename)
2079      self.assertIsNotNone(new_saver)
2080      # Generates a new MetaGraphDef.
2081      new_meta_graph_def = new_saver.export_meta_graph()
2082      # It should be the same as the original.
2083      test_util.assert_meta_graph_protos_equal(self, meta_graph_def,
2084                                               new_meta_graph_def)
2085
  def _testGraphExtensionSave(self, test_dir):
    """Builds a small inference graph and saves checkpoint + MetaGraphDef.

    The hidden-layer biases are deliberately wrapped in `cond` and
    `while_loop` so the export covers control-flow contexts; writes the
    checkpoint "saver0.ckpt" and the MetaGraphDef "metafile" into `test_dir`.
    """
    filename = os.path.join(test_dir, "metafile")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    # Creates an inference graph.
    # Hidden 1
    images = constant_op.constant(1.2, dtypes.float32, shape=[100, 28])
    with ops_lib.name_scope("hidden1"):
      weights = variables.VariableV1(
          random_ops.truncated_normal(
              [28, 128], stddev=1.0 / math.sqrt(float(28))),
          name="weights")
      # The use of control_flow_ops.cond here is purely for adding test coverage
      # the save and restore of control flow context (which doesn't make any
      # sense here from a machine learning perspective).  The typical biases is
      # a simple Variable without the conditions.
      biases = variables.VariableV1(
          control_flow_ops.cond(
              math_ops.less(random.random(), 0.5),
              lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
          name="biases")
      hidden1 = nn_ops.relu(math_ops.matmul(images, weights) + biases)
    # Hidden 2
    with ops_lib.name_scope("hidden2"):
      weights = variables.VariableV1(
          random_ops.truncated_normal(
              [128, 32], stddev=1.0 / math.sqrt(float(128))),
          name="weights")

      # The use of control_flow_ops.while_loop here is purely for adding test
      # coverage the save and restore of control flow context (which doesn't
      # make any sense here from a machine learning perspective).  The typical
      # biases is a simple Variable without the conditions.
      def loop_cond(it, _):
        return it < 2

      def loop_body(it, biases):
        biases += constant_op.constant(0.1, shape=[32])
        return it + 1, biases

      _, biases = control_flow_ops.while_loop(
          loop_cond, loop_body,
          [constant_op.constant(0),
           variables.VariableV1(array_ops.zeros([32]))])
      hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
    # Linear
    with ops_lib.name_scope("softmax_linear"):
      weights = variables.VariableV1(
          random_ops.truncated_normal(
              [32, 10], stddev=1.0 / math.sqrt(float(32))),
          name="weights")
      biases = variables.VariableV1(array_ops.zeros([10]), name="biases")
      logits = math_ops.matmul(hidden2, weights) + biases
      # Stored in a collection so the restore phase can look it up by name.
      ops_lib.add_to_collection("logits", logits)
    init_all_op = variables.global_variables_initializer()

    with self.cached_session() as sess:
      # Initializes all the variables.
      self.evaluate(init_all_op)
      # Runs to logit.
      self.evaluate(logits)
      # Creates a saver.
      saver0 = saver_module.Saver()
      saver0.save(sess, saver0_ckpt)
      # Generates MetaGraphDef.
      saver0.export_meta_graph(filename)
2151
  def _testGraphExtensionRestore(self, test_dir):
    """Restores the saved inference graph and extends it with training ops.

    Imports the MetaGraphDef written by _testGraphExtensionSave, restores the
    checkpoint, attaches loss/optimizer ops, runs one training step, and
    exports the extended graph to "train_metafile".
    """
    filename = os.path.join(test_dir, "metafile")
    train_filename = os.path.join(test_dir, "train_metafile")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    with self.session(graph=ops_lib.Graph()) as sess:
      # Restores from MetaGraphDef.
      new_saver = saver_module.import_meta_graph(filename)
      # Generates a new MetaGraphDef.
      new_saver.export_meta_graph()
      # Restores from checkpoint.
      new_saver.restore(sess, saver0_ckpt)
      # Adds loss and train.
      labels = constant_op.constant(0, dtypes.int32, shape=[100], name="labels")
      batch_size = array_ops.size(labels)
      labels = array_ops.expand_dims(labels, 1)
      indices = array_ops.expand_dims(math_ops.range(0, batch_size), 1)
      concated = array_ops.concat([indices, labels], 1)
      # Builds dense one-hot labels from (row, class) index pairs.
      onehot_labels = sparse_ops.sparse_to_dense(
          concated, array_ops.stack([batch_size, 10]), 1.0, 0.0)
      # The logits tensor was stashed in a collection by the save phase.
      logits = ops_lib.get_collection("logits")[0]
      cross_entropy = nn_ops.softmax_cross_entropy_with_logits(
          labels=onehot_labels, logits=logits, name="xentropy")
      loss = math_ops.reduce_mean(cross_entropy, name="xentropy_mean")

      summary.scalar("loss", loss)
      # Creates the gradient descent optimizer with the given learning rate.
      optimizer = gradient_descent.GradientDescentOptimizer(0.01)

      # Runs train_op.
      train_op = optimizer.minimize(loss)
      ops_lib.add_to_collection("train_op", train_op)

      # Runs train_op.
      self.evaluate(train_op)

      # Generates MetaGraphDef.
      saver_module.export_meta_graph(train_filename)
2189
2190  def _testRestoreFromTrainGraphWithControlContext(self, test_dir):
2191    train_filename = os.path.join(test_dir, "train_metafile")
2192    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
2193    with self.session(graph=ops_lib.Graph()) as sess:
2194      # Restores from MetaGraphDef.
2195      new_saver = saver_module.import_meta_graph(train_filename)
2196      # Restores from checkpoint.
2197      new_saver.restore(sess, saver0_ckpt)
2198      train_op = ops_lib.get_collection("train_op")[0]
2199      self.evaluate(train_op)
2200
2201  def testGraphExtension(self):
2202    test_dir = self._get_test_dir("graph_extension")
2203    # train.Saver and train.import_meta_graph are V1 only APIs.
2204    with ops_lib.Graph().as_default():
2205      self._testGraphExtensionSave(test_dir)
2206      self._testGraphExtensionRestore(test_dir)
2207      self._testRestoreFromTrainGraphWithControlContext(test_dir)
2208
  def _testGradientSerDes(self, graph_fn):
    """Tests that gradients can be computed after exporting and importing.

    Builds a graph, exports it, and verifies that it can be imported and the
    gradient can be built and run correctly, matching the gradient value
    computed in the original graph.

    Args:
      graph_fn: takes a single float Tensor argument as input, outputs a single
        Tensor
    """
    test_dir = self._get_test_dir("nested_control_flow")
    filename = os.path.join(test_dir, "metafile")
    saver_ckpt = os.path.join(test_dir, "saver.ckpt")

    # Create while loop using `outer_body_fn`.
    with ops_lib.Graph().as_default():
      var = variables.VariableV1(0.0)
      var_name = var.name
      output = graph_fn(var)
      output_name = output.name
      init_op = variables.global_variables_initializer()

      # Generate a MetaGraphDef containing the while loop.
      with session.Session() as sess:
        self.evaluate(init_op)
        self.evaluate(output)
        saver = saver_module.Saver()
        saver.save(sess, saver_ckpt)
        saver.export_meta_graph(filename)

      # Build and run the gradients of the while loop. We use this below to
      # verify that the gradients are correct with an imported MetaGraphDef.
      grad = gradients_impl.gradients([output], [var])
      # Turn off constant folding to avoid breaking testNestedControlFlowSerDes.
      # It appears that a missing control dependency in the gradient graph
      # causes the fetch node to not be triggered.
      no_constfold_config = config_pb2.ConfigProto()
      no_constfold_config.graph_options.rewrite_options.constant_folding = (
          rewriter_config_pb2.RewriterConfig.OFF)
      with session.Session(config=no_constfold_config) as sess:
        self.evaluate(init_op)
        expected_grad_value = self.evaluate(grad)

    # Restore the MetaGraphDef into a new Graph.
    with ops_lib.Graph().as_default():
      with session.Session() as sess:
        saver = saver_module.import_meta_graph(filename)
        saver.restore(sess, saver_ckpt)

      # Make sure we can still build gradients and get the same result.
      # Tensors are re-fetched by the names recorded before export.
      var = ops_lib.get_default_graph().get_tensor_by_name(var_name)
      output = ops_lib.get_default_graph().get_tensor_by_name(output_name)
      grad = gradients_impl.gradients([output], [var])

      init_op = variables.global_variables_initializer()

      with session.Session(config=no_constfold_config) as sess:
        self.evaluate(init_op)
        actual_grad_value = self.evaluate(grad)
        self.assertEqual(expected_grad_value, actual_grad_value)
2269
2270  def _testWhileLoopAndGradientSerDes(self, outer_body_fn):
2271    # Build a while loop with `outer_body_fn`, export it, and verify that it can
2272    # be imported and the gradient can be built and run correctly.
2273    # pylint: disable=g-long-lambda
2274    return self._testGradientSerDes(
2275        lambda x: control_flow_ops.while_loop(
2276            lambda i, y: i < 5, outer_body_fn, [0, x])[1])
2277    # pylint: enable=g-long-lambda
2278
2279  def testNestedWhileLoopsSerDes(self):
2280    # Test two simple nested while loops.
2281    def body(i, x):
2282      _, r = control_flow_ops.while_loop(lambda j, y: j < 3,
2283                                         lambda j, y: (j + 1, y + x),
2284                                         [0, 0.0])
2285      return i + 1, x + r
2286    self._testWhileLoopAndGradientSerDes(body)
2287
  def testNestedControlFlowSerDes(self):
    """SerDes of a while loop inside a cond inside an outer while loop."""
    # Test while loop in a cond in a while loop.
    # pylint: disable=g-long-lambda
    def body(i, x):
      # After the first iteration (i > 0) run the inner loop; on the first
      # iteration just pass x through unchanged.
      cond_result = control_flow_ops.cond(
          i > 0,
          lambda: control_flow_ops.while_loop(
              lambda j, y: j < 3,
              lambda j, y: (j + 1, y + x),
              [0, 0.0])[1],
          lambda: x)
      return i + 1, cond_result
    # pylint: enable=g-long-lambda
    self._testWhileLoopAndGradientSerDes(body)
2302
  def testNestedCondsSerDes(self):
    """SerDes of conds nested inside a cond (no while loops involved)."""
    # Test conds in a cond.
    # pylint: disable=g-long-lambda
    self._testGradientSerDes(lambda x: control_flow_ops.cond(
        x > 0,
        lambda: control_flow_ops.cond(x > 3,
                                      lambda: array_ops.identity(x),
                                      lambda: math_ops.multiply(x, 2.0)),
        lambda: control_flow_ops.cond(x < -3,
                                      lambda: constant_op.constant(1.0),
                                      lambda: math_ops.multiply(x, -1.0))))
    # pylint: enable=g-long-lambda
2315
2316  @test_util.run_v1_only("This exercises Tensor.op which is meaningless in V2.")
2317  def testStrippedOpListDef(self):
2318    with self.cached_session():
2319      # Creates a graph.
2320      v0 = variables.VariableV1(0.0)
2321      var = variables.VariableV1(10.0)
2322      math_ops.add(v0, var)
2323
2324      @function.Defun(dtypes.float32)
2325      def minus_one(x):
2326        return x - 1
2327
2328      minus_one(array_ops.identity(v0))
2329      save = saver_module.Saver({"v0": v0})
2330      variables.global_variables_initializer()
2331
2332      # Generates MetaGraphDef.
2333      meta_graph_def = save.export_meta_graph()
2334      ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op]
2335      if save._write_version is saver_pb2.SaverDef.V1:
2336        self.assertEqual(ops, [
2337            "Add", "Assign", "Const", "Identity", "NoOp",
2338            "PlaceholderWithDefault", "RestoreV2", "SaveSlices", "Sub",
2339            "VariableV2"
2340        ])
2341      else:
2342        self.assertEqual(ops, [
2343            "Add", "Assign", "Const", "Identity", "NoOp",
2344            "PlaceholderWithDefault", "RestoreV2", "SaveV2", "Sub", "VariableV2"
2345        ])
2346
2347      # Test calling stripped_op_list_for_graph directly
2348      op_list = meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def)
2349      self.assertEqual(ops, [o.name for o in op_list.op])
2350      for o in op_list.op:
2351        self.assertEqual(o.summary, "")
2352        self.assertEqual(o.description, "")
2353
2354  def testStripDefaultValuedAttrs(self):
2355    """Verifies that default valued attrs are stripped, unless disabled."""
2356
2357    # With strip_default_attrs enabled, attributes "T" (float32) and "Tout"
2358    # (complex64) in the "Complex" op must be removed.
2359    # train.Saver and train.export_meta_graph are V1 only APIs.
2360    with ops_lib.Graph().as_default(), self.cached_session():
2361      real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
2362      imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
2363      math_ops.complex(real_num, imag_num, name="complex")
2364
2365      save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
2366      variables.global_variables_initializer()
2367
2368      meta_graph_def = save.export_meta_graph(strip_default_attrs=True)
2369      node_def = test_util.get_node_def_from_graph("complex",
2370                                                   meta_graph_def.graph_def)
2371      self.assertNotIn("T", node_def.attr)
2372      self.assertNotIn("Tout", node_def.attr)
2373
2374    # With strip_default_attrs disabled, attributes "T" (float32) and "Tout"
2375    # (complex64) in the "Complex" op must *not* be removed, even if they map
2376    # to their defaults.
2377    with ops_lib.Graph().as_default(), self.session():
2378      real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
2379      imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
2380      math_ops.complex(real_num, imag_num, name="complex")
2381
2382      save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
2383      variables.global_variables_initializer()
2384
2385      meta_graph_def = save.export_meta_graph(strip_default_attrs=False)
2386      node_def = test_util.get_node_def_from_graph("complex",
2387                                                   meta_graph_def.graph_def)
2388      self.assertIn("T", node_def.attr)
2389      self.assertIn("Tout", node_def.attr)
2390
2391  def testImportIntoNamescope(self):
2392    # Test that we can import a meta graph into a namescope.
2393    test_dir = self._get_test_dir("import_into_namescope")
2394    filename = os.path.join(test_dir, "ckpt")
2395    # train.Saver is V1 only API.
2396    with ops_lib.Graph().as_default():
2397      image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
2398      label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
2399      with session.Session() as sess:
2400        weights = variables.VariableV1(
2401            random_ops.random_uniform([784, 10]), name="weights")
2402        bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
2403        logit = nn_ops.relu(
2404            math_ops.matmul(image, weights) + bias, name="logits")
2405        nn_ops.softmax(logit, name="prediction")
2406        cost = nn_ops.softmax_cross_entropy_with_logits(
2407            labels=label, logits=logit, name="cost")
2408        adam.AdamOptimizer().minimize(cost, name="optimize")
2409        saver = saver_module.Saver()
2410        self.evaluate(variables.global_variables_initializer())
2411        saver.save(sess, filename)
2412
2413    graph = ops_lib.Graph()
2414    with session.Session(graph=graph) as sess:
2415      new_saver = saver_module.import_meta_graph(
2416          filename + ".meta", graph=graph, import_scope="new_model")
2417      new_saver.restore(sess, filename)
2418      sess.run(["new_model/optimize"], {
2419          "new_model/image:0": np.random.random([1, 784]),
2420          "new_model/label:0": np.random.randint(
2421              10, size=[1, 10])
2422      })
2423
2424  def testImportIntoNamescopeWithoutVariables(self):
2425    # Save a simple graph that contains no variables into a checkpoint.
2426    test_dir = self._get_test_dir("no_vars_graph")
2427    filename = os.path.join(test_dir, "ckpt")
2428    graph_1 = ops_lib.Graph()
2429    with session.Session(graph=graph_1) as sess:
2430      constant_op.constant([1, 2, 3], name="x")
2431      constant_op.constant([1, 2, 3], name="y")
2432      saver = saver_module.Saver(allow_empty=True)
2433      saver.save(sess, filename)
2434
2435    # Create a fresh graph.
2436    graph_2 = ops_lib.Graph()
2437    with session.Session(graph=graph_2) as sess:
2438      # Restore the above checkpoint under scope "subgraph_1".
2439      new_saver_1 = saver_module.import_meta_graph(
2440          filename + ".meta", graph=graph_2, import_scope="subgraph_1")
2441      # There are no variables to restore, so import_meta_graph should not
2442      # return a Saver.
2443      self.assertIsNone(new_saver_1)
2444
2445      # Create a variable in graph_2 under scope "my_scope".
2446      variables.VariableV1(array_ops.zeros([10]), name="my_scope/my_var")
2447      self.evaluate(variables.global_variables_initializer())
2448      # Restore the checkpoint into a different scope "subgraph_2".
2449      new_saver_2 = saver_module.import_meta_graph(
2450          filename + ".meta", graph=graph_2, import_scope="subgraph_2")
2451      # Because the variable does not live in scope "subgraph_2",
2452      # import_meta_graph should not attempt to restore the variable. So,
2453      # import_meta_graph still won't return a Saver instance.
2454      self.assertIsNone(new_saver_2)
2455
2456      # However, if we restore the checkpoint under scope "my_scope",
2457      # import_meta_graph will detect the variable and return a Saver for
2458      # restoring it. This should happen even when the variable does not
2459      # originate from graph_1.
2460      new_saver_3 = saver_module.import_meta_graph(
2461          filename + ".meta", graph=graph_2, import_scope="my_scope")
2462      self.assertIsInstance(new_saver_3, saver_module.Saver)
2463
2464  def testImportIntoImplicitNamescope(self):
2465    # Test that we can import a meta graph into an implicit namescope.
2466    test_dir = self._get_test_dir("import_into_namescope")
2467    filename = os.path.join(test_dir, "ckpt")
2468    # train.Saver is V1 only API.
2469    with ops_lib.Graph().as_default():
2470      image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
2471      label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
2472      with session.Session() as sess:
2473        weights = variables.VariableV1(
2474            random_ops.random_uniform([784, 10]), name="weights")
2475        bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
2476        logit = nn_ops.relu(
2477            math_ops.matmul(image, weights) + bias, name="logits")
2478        nn_ops.softmax(logit, name="prediction")
2479        cost = nn_ops.softmax_cross_entropy_with_logits(
2480            labels=label, logits=logit, name="cost")
2481        adam.AdamOptimizer().minimize(cost, name="optimize")
2482        saver = saver_module.Saver()
2483        self.evaluate(variables.global_variables_initializer())
2484        saver.save(sess, filename)
2485
2486    graph = ops_lib.Graph()
2487    with session.Session(graph=graph) as sess:
2488      with ops_lib.name_scope("new_model"):
2489        new_saver = saver_module.import_meta_graph(
2490            filename + ".meta", graph=graph)
2491
2492      new_saver.restore(sess, filename)
2493      sess.run(["new_model/optimize"], {
2494          "new_model/image:0": np.random.random([1, 784]),
2495          "new_model/label:0": np.random.randint(
2496              10, size=[1, 10])
2497      })
2498
2499  def testClearDevicesOnImport(self):
2500    # Test that we import a graph without its devices and run successfully.
2501    with ops_lib.Graph().as_default():
2502      with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
2503        image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
2504        label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
2505        weights = variables.VariableV1(
2506            random_ops.random_uniform([784, 10]), name="weights")
2507        bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
2508        logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
2509        nn_ops.softmax(logit, name="prediction")
2510        cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
2511                                                        logits=logit)
2512        adam.AdamOptimizer().minimize(cost, name="optimize")
2513      meta_graph_def = saver_module.export_meta_graph()
2514
2515    with session.Session(graph=ops_lib.Graph()) as sess:
2516      saver_module.import_meta_graph(
2517          meta_graph_def, clear_devices=False, import_scope="new_model")
2518      # Device refers to GPU, which is not available here.
2519      with self.assertRaises(errors_impl.InvalidArgumentError):
2520        self.evaluate(variables.global_variables_initializer())
2521
2522    with session.Session(graph=ops_lib.Graph()) as sess:
2523      saver_module.import_meta_graph(
2524          meta_graph_def, clear_devices=True, import_scope="new_model")
2525      self.evaluate(variables.global_variables_initializer())
2526      sess.run(["new_model/optimize"], {
2527          "new_model/image:0": np.random.random([1, 784]),
2528          "new_model/label:0": np.random.randint(
2529              10, size=[1, 10])
2530      })
2531
2532  def testClearDevicesOnExport(self):
2533    # Test that we export a graph without its devices and run successfully.
2534    with ops_lib.Graph().as_default():
2535      with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
2536        image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
2537        label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
2538        weights = variables.VariableV1(
2539            random_ops.random_uniform([784, 10]), name="weights")
2540        bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
2541        logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
2542        nn_ops.softmax(logit, name="prediction")
2543        cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
2544                                                        logits=logit)
2545        adam.AdamOptimizer().minimize(cost, name="optimize")
2546      meta_graph_def = saver_module.export_meta_graph(clear_devices=True)
2547      graph_io.write_graph(meta_graph_def, self.get_temp_dir(),
2548                           "meta_graph.pbtxt")
2549
2550    with session.Session(graph=ops_lib.Graph()) as sess:
2551      saver_module.import_meta_graph(meta_graph_def, import_scope="new_model")
2552      self.evaluate(variables.global_variables_initializer())
2553      sess.run(["new_model/optimize"], {
2554          "new_model/image:0": np.random.random([1, 784]),
2555          "new_model/label:0": np.random.randint(
2556              10, size=[1, 10])
2557      })
2558
2559  def testPreserveDatasetAndFunctions(self):
2560    with ops_lib.Graph().as_default() as g:
2561      dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x)
2562      iterator = dataset_ops.make_one_shot_iterator(dataset)
2563      next_element = iterator.get_next()
2564      _ = array_ops.identity(next_element, name="output")
2565
2566      # Generate three MetaGraphDef protos using different code paths.
2567      meta_graph_def_simple = saver_module.export_meta_graph()
2568      meta_graph_def_devices_cleared = saver_module.export_meta_graph(
2569          clear_devices=True)
2570      meta_graph_def_from_graph_def = saver_module.export_meta_graph(
2571          clear_devices=True, graph_def=g.as_graph_def())
2572
2573    for meta_graph_def in [meta_graph_def_simple,
2574                           meta_graph_def_devices_cleared,
2575                           meta_graph_def_from_graph_def]:
2576      with session.Session(graph=ops_lib.Graph()) as sess:
2577        saver_module.import_meta_graph(meta_graph_def, import_scope="new_model")
2578        self.evaluate(variables.global_variables_initializer())
2579        for i in range(10):
2580          self.assertEqual(i * i, sess.run("new_model/output:0"))
2581        with self.assertRaises(errors.OutOfRangeError):
2582          sess.run("new_model/output:0")
2583
2584
class CheckpointReaderTest(test.TestCase):
  """Exercises py_checkpoint_reader.NewCheckpointReader on saved checkpoints."""

  # Checkpoint format written by the Saver; a subclass overrides this to V2.
  _WRITE_VERSION = saver_pb2.SaverDef.V1

  def testDebugString(self):
    """Saves two variables and inspects them through a CheckpointReader."""
    # Builds a graph.
    v0 = variables.VariableV1(
        [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
    v1 = variables.VariableV1(
        [[[1], [2]], [[3], [4]], [[5], [6]]], dtype=dtypes.float32, name="v1")
    init_all_op = variables.global_variables_initializer()
    save = saver_module.Saver(
        {
            "v0": v0,
            "v1": v1
        }, write_version=self._WRITE_VERSION)
    save_path = os.path.join(self.get_temp_dir(),
                             "ckpt_for_debug_string" + str(self._WRITE_VERSION))
    with self.cached_session() as sess:
      self.evaluate(init_all_op)
      # Saves a checkpoint.
      save.save(sess, save_path)

      # Creates a reader.
      reader = py_checkpoint_reader.NewCheckpointReader(save_path)
      # Verifies that the tensors exist.
      self.assertTrue(reader.has_tensor("v0"))
      self.assertTrue(reader.has_tensor("v1"))
      debug_string = reader.debug_string()
      # Verifies that debug string contains the right strings.  assertIn
      # reports the actual string on failure, unlike assertTrue(x in y).
      self.assertIn(compat.as_bytes("v0 (DT_FLOAT) [2,3]"), debug_string)
      self.assertIn(compat.as_bytes("v1 (DT_FLOAT) [3,2,1]"), debug_string)
      # Verifies get_variable_to_shape_map() returns the correct information.
      var_map = reader.get_variable_to_shape_map()
      self.assertEqual([2, 3], var_map["v0"])
      self.assertEqual([3, 2, 1], var_map["v1"])
      # Verifies get_tensor() returns the tensor value.
      v0_tensor = reader.get_tensor("v0")
      v1_tensor = reader.get_tensor("v1")
      self.assertAllEqual(v0, v0_tensor)
      self.assertAllEqual(v1, v1_tensor)
      # Verifies get_tensor() fails for non-existent tensors.
      with self.assertRaisesRegex(errors.NotFoundError,
                                  "v3 not found in checkpoint"):
        reader.get_tensor("v3")

  def testNonexistentPath(self):
    """Opening a missing checkpoint path raises NotFoundError."""
    with self.assertRaisesRegex(errors.NotFoundError,
                                "Unsuccessful TensorSliceReader"):
      py_checkpoint_reader.NewCheckpointReader("non-existent")
2635
2636
class CheckpointReaderForV2Test(CheckpointReaderTest):
  """Re-runs all CheckpointReaderTest cases with V2-format checkpoints."""

  _WRITE_VERSION = saver_pb2.SaverDef.V2
2639
2640
class WriteGraphTest(test.TestCase):
  """Tests for graph_io.write_graph."""

  def _get_test_dir(self, dirname):
    """Creates (if needed) and returns a scratch directory under temp."""
    path = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(path)
    return path

  def testWriteGraph(self):
    """write_graph accepts a Graph object and returns the path it wrote."""
    test_dir = self._get_test_dir("write_graph_dir")
    variables.VariableV1(
        [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
    out_path = graph_io.write_graph(ops_lib.get_default_graph(),
                                    os.path.join(test_dir, "l1"),
                                    "graph.pbtxt")
    self.assertEqual(out_path, os.path.join(test_dir, "l1", "graph.pbtxt"))
    self.assertTrue(os.path.exists(out_path))

  def testRecursiveCreate(self):
    """write_graph creates any missing intermediate directories."""
    test_dir = self._get_test_dir("deep_dir")
    variables.VariableV1(
        [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
    out_path = graph_io.write_graph(
        ops_lib.get_default_graph().as_graph_def(),
        os.path.join(test_dir, "l1", "l2", "l3"), "graph.pbtxt")
    self.assertEqual(out_path,
                     os.path.join(test_dir, "l1", "l2", "l3", "graph.pbtxt"))
    self.assertTrue(os.path.exists(out_path))
2668
2669
2670class ScopedGraphTest(test.TestCase):
2671
2672  def _get_test_dir(self, dirname):
2673    test_dir = os.path.join(self.get_temp_dir(), dirname)
2674    gfile.MakeDirs(test_dir)
2675    return test_dir
2676
  def _testScopedSave(self, test_dir, exported_filename, ckpt_filename):
    """Exports only the "hidden1" scope of a graph and saves its variables.

    Builds the cond/while_loop-laden inference graph used elsewhere in this
    file, exports the "hidden1" sub-graph to `exported_filename` via
    export_scoped_meta_graph, and checkpoints the scoped variables to
    `ckpt_filename` under `test_dir`.
    """
    graph = ops_lib.Graph()
    with graph.as_default():
      # Creates an inference graph.
      # Hidden 1
      images = constant_op.constant(
          1.2, dtypes.float32, shape=[100, 28], name="images")
      with ops_lib.name_scope("hidden1"):
        weights1 = variables.VariableV1(
            random_ops.truncated_normal(
                [28, 128], stddev=1.0 / math.sqrt(float(28))),
            name="weights")
        # The use of control_flow_ops.cond here is purely for adding test
        # coverage the save and restore of control flow context (which doesn't
        # make any sense here from a machine learning perspective).  The typical
        # biases is a simple Variable without the conditions.
        biases1 = variables.VariableV1(
            control_flow_ops.cond(
                math_ops.less(random.random(), 0.5),
                lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
            name="biases")
        hidden1 = nn_ops.relu(math_ops.matmul(images, weights1) + biases1)

      # Hidden 2
      with ops_lib.name_scope("hidden2"):
        weights2 = variables.VariableV1(
            random_ops.truncated_normal(
                [128, 32], stddev=1.0 / math.sqrt(float(128))),
            name="weights")

        # The use of control_flow_ops.while_loop here is purely for adding test
        # coverage the save and restore of control flow context (which doesn't
        # make any sense here from a machine learning perspective).  The typical
        # biases is a simple Variable without the conditions.
        def loop_cond(it, _):
          return it < 2

        def loop_body(it, biases2):
          biases2 += constant_op.constant(0.1, shape=[32])
          return it + 1, biases2

        _, biases2 = control_flow_ops.while_loop(loop_cond, loop_body, [
            constant_op.constant(0), variables.VariableV1(array_ops.zeros([32]))
        ])
        hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2)
      # Linear
      with ops_lib.name_scope("softmax_linear"):
        weights3 = variables.VariableV1(
            random_ops.truncated_normal(
                [32, 10], stddev=1.0 / math.sqrt(float(32))),
            name="weights")
        biases3 = variables.VariableV1(array_ops.zeros([10]), name="biases")
        logits = math_ops.matmul(hidden2, weights3) + biases3
        ops_lib.add_to_collection("logits", logits)

        # Adds user_defined proto in three formats: string, bytes and Any.
        # Any proto should just pass through.
        queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")
        ops_lib.add_to_collection("user_defined_string_collection",
                                  str(queue_runner))
        ops_lib.add_to_collection("user_defined_bytes_collection",
                                  queue_runner.SerializeToString())
        any_buf = Any()
        any_buf.Pack(queue_runner)
        ops_lib.add_to_collection("user_defined_any_collection", any_buf)

      # Export only the "hidden1" scope; var_list maps scope-stripped
      # variable names to the variables themselves.
      _, var_list = meta_graph.export_scoped_meta_graph(
          filename=os.path.join(test_dir, exported_filename),
          graph=ops_lib.get_default_graph(),
          export_scope="hidden1")
      self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))

    with graph.as_default(), self.session() as sess:
      self.evaluate(variables.global_variables_initializer())
      saver = saver_module.Saver(var_list=var_list, max_to_keep=1)
      saver.save(sess, os.path.join(test_dir, ckpt_filename), write_state=False)
2753
2754  def _testScopedRestore(self, test_dir, exported_filename,
2755                         new_exported_filename, ckpt_filename):
2756    graph = ops_lib.Graph()
2757    # Create all the missing inputs.
2758    with graph.as_default():
2759      new_image = constant_op.constant(
2760          1.2, dtypes.float32, shape=[100, 28], name="images")
2761      var_list = meta_graph.import_scoped_meta_graph(
2762          os.path.join(test_dir, exported_filename),
2763          graph=graph,
2764          input_map={"$unbound_inputs_images": new_image},
2765          import_scope="new_hidden1")
2766      self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
2767      hidden1 = graph.as_graph_element("new_hidden1/Relu:0")
2768      weights1 = graph.as_graph_element("new_hidden1/weights:0")
2769      biases1 = graph.as_graph_element("new_hidden1/biases:0")
2770
2771    with graph.as_default():
2772      # Hidden 2
2773      with ops_lib.name_scope("hidden2"):
2774        weights = variables.VariableV1(
2775            random_ops.truncated_normal(
2776                [128, 32], stddev=1.0 / math.sqrt(float(128))),
2777            name="weights")
2778
2779        # The use of control_flow_ops.while_loop here is purely for adding test
2780        # coverage the save and restore of control flow context (which doesn't
2781        # make any sense here from a machine learning perspective).  The typical
2782        # biases is a simple Variable without the conditions.
2783        def loop_cond(it, _):
2784          return it < 2
2785
2786        def loop_body(it, biases):
2787          biases += constant_op.constant(0.1, shape=[32])
2788          return it + 1, biases
2789
2790        _, biases = control_flow_ops.while_loop(loop_cond, loop_body, [
2791            constant_op.constant(0), variables.VariableV1(array_ops.zeros([32]))
2792        ])
2793        hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
2794      # Linear
2795      with ops_lib.name_scope("softmax_linear"):
2796        weights = variables.VariableV1(
2797            random_ops.truncated_normal(
2798                [32, 10], stddev=1.0 / math.sqrt(float(32))),
2799            name="weights")
2800        biases = variables.VariableV1(array_ops.zeros([10]), name="biases")
2801        logits = math_ops.matmul(hidden2, weights) + biases
2802        ops_lib.add_to_collection("logits", logits)
2803
2804      # The rest of the variables.
2805      rest_variables = list(
2806          set(variables.global_variables()) - set(var_list.keys()))
2807      init_rest_op = variables.variables_initializer(rest_variables)
2808
2809    with graph.as_default(), self.session() as sess:
2810      saver = saver_module.Saver(var_list=var_list, max_to_keep=1)
2811      saver.restore(sess, os.path.join(test_dir, ckpt_filename))
2812      # Verify that we have restored weights1 and biases1.
2813      self.evaluate([weights1, biases1])
2814      # Initialize the rest of the variables and run logits.
2815      self.evaluate(init_rest_op)
2816      self.evaluate(logits)
2817
2818  # Verifies that we can save the subgraph under "hidden1" and restore it
2819  # into "new_hidden1" in the new graph.
2820  def testScopedSaveAndRestore(self):
2821    test_dir = self._get_test_dir("scoped_export_import")
2822    ckpt_filename = "ckpt"
2823    self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename)
2824    self._testScopedRestore(test_dir, "exported_hidden1.pbtxt",
2825                            "exported_new_hidden1.pbtxt", ckpt_filename)
2826
2827  # Verifies that we can copy the subgraph under "hidden1" and copy it
2828  # to different name scope in the same graph or different graph.
  def testCopyScopedGraph(self):
    """Copies the "hidden1" scope within a graph and across graphs.

    Builds a small relu network under scope "hidden1", saves a scoped
    checkpoint, then verifies that copy_scoped_meta_graph:
      * rejects a copy onto the same scope in the same graph,
      * supports copying to a new scope in the same graph, and
      * supports copying to a new scope in a different graph,
    with each copy restorable from the original scoped checkpoint.
    """
    test_dir = self._get_test_dir("scoped_copy")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    graph1 = ops_lib.Graph()
    with graph1.as_default():
      with ops_lib.name_scope("hidden1"):
        images = constant_op.constant(
            1.0, dtypes.float32, shape=[3, 2], name="images")
        weights1 = variables.VariableV1(
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
        biases1 = variables.VariableV1([0.1] * 3, name="biases")
        nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")

    # Run the graph and save scoped checkpoint.
    with graph1.as_default(), self.session(graph=graph1) as sess:
      self.evaluate(variables.global_variables_initializer())
      _, var_list_1 = meta_graph.export_scoped_meta_graph(
          export_scope="hidden1")
      saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
      saver.save(sess, saver0_ckpt, write_state=False)

    # Expected relu output: each row is matmul([1, 1], weights) + 0.1,
    # i.e. [5.1, 7.1, 9.1], repeated for the three input rows.
    expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3))

    # Verifies copy to the same graph with the same name fails.
    with graph1.as_default():
      with self.assertRaisesWithPredicateMatch(
          ValueError, lambda e: "need to be different" in str(e)):
        meta_graph.copy_scoped_meta_graph(
            from_scope="hidden1", to_scope="hidden1")

    # Verifies copy to the same graph.
    with graph1.as_default():
      var_list_2 = meta_graph.copy_scoped_meta_graph(
          from_scope="hidden1", to_scope="hidden2")

    with graph1.as_default(), self.session(graph=graph1) as sess:
      # Both the original and the copied scope restore from the same
      # checkpoint and must produce identical relu outputs.
      saver1 = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
      saver1.restore(sess, saver0_ckpt)
      saver2 = saver_module.Saver(var_list=var_list_2, max_to_keep=1)
      saver2.restore(sess, saver0_ckpt)
      self.assertAllClose(expected, sess.run("hidden1/relu:0"))
      self.assertAllClose(expected, sess.run("hidden2/relu:0"))

    # Verifies copy to different graph.
    graph2 = ops_lib.Graph()
    with graph2.as_default():
      new_var_list_1 = meta_graph.copy_scoped_meta_graph(
          from_scope="hidden1",
          to_scope="new_hidden1",
          from_graph=graph1,
          to_graph=graph2)

      with self.session() as sess:
        saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)
        saver3.restore(sess, saver0_ckpt)
        self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))
2885
  def testExportGraphDefWithScope(self):
    """Exports scope "hidden1" from an explicit graph_def, then restores a copy.

    Same network as testCopyScopedGraph, but export_scoped_meta_graph is given
    graph_def=graph1.as_graph_def() explicitly instead of relying on the
    default graph.
    """
    test_dir = self._get_test_dir("export_graph_def")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    graph1 = ops_lib.Graph()
    with graph1.as_default():
      with ops_lib.name_scope("hidden1"):
        images = constant_op.constant(
            1.0, dtypes.float32, shape=[3, 2], name="images")
        weights1 = variables.VariableV1(
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
        biases1 = variables.VariableV1([0.1] * 3, name="biases")
        nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")

      # Run the graph and save scoped checkpoint.
      with self.session(graph=graph1) as sess:
        self.evaluate(variables.global_variables_initializer())
        _, var_list_1 = meta_graph.export_scoped_meta_graph(
            graph_def=graph1.as_graph_def(), export_scope="hidden1")
        saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
        saver.save(sess, saver0_ckpt, write_state=False)

    # Expected relu output: each row is matmul([1, 1], weights) + 0.1.
    expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3))

    # Verifies that we can run successfully after restoring.
    graph2 = ops_lib.Graph()
    with graph2.as_default():
      new_var_list_1 = meta_graph.copy_scoped_meta_graph(
          from_scope="hidden1",
          to_scope="new_hidden1",
          from_graph=graph1,
          to_graph=graph2)

      with self.session(graph=graph2) as sess:
        saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)
        saver3.restore(sess, saver0_ckpt)
        self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))
2922
  def testSerializeSaverWithScope(self):
    """Verifies savers in the SAVERS collection survive a scoped copy.

    Creates two savers -- one defined inside name scope "hidden1" and one
    defined outside but named with the "hidden2/" prefix -- then copies each
    scope into a fresh graph and checks that the copied saver can restore the
    copied variable from the original checkpoint.
    """
    test_dir = self._get_test_dir("export_graph_def")
    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
    saver2_ckpt = os.path.join(test_dir, "saver2.ckpt")
    graph = ops_lib.Graph()
    with graph.as_default():
      with ops_lib.name_scope("hidden1"):
        variable1 = variables.VariableV1([1.0], name="variable1")
        saver1 = saver_module.Saver(var_list=[variable1])
        graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver1)

      with ops_lib.name_scope("hidden2"):
        variable2 = variables.VariableV1([2.0], name="variable2")
      # NOTE(review): saver2 is built outside the scope; presumably the
      # explicit "hidden2/" name prefix places its ops under that scope so
      # the scoped copy below picks it up -- confirm against Saver naming.
      saver2 = saver_module.Saver(var_list=[variable2], name="hidden2/")
      graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver2)

      with self.session(graph=graph) as sess:
        self.evaluate(variables.global_variables_initializer())
        saver1.save(sess, saver1_ckpt, write_state=False)
        saver2.save(sess, saver2_ckpt, write_state=False)

    graph1 = ops_lib.Graph()
    with graph1.as_default():
      # Copy scope "hidden1" (one variable plus one saver) to a new graph.
      var_dict1 = meta_graph.copy_scoped_meta_graph(
          from_scope="hidden1",
          to_scope="new_hidden1",
          from_graph=graph,
          to_graph=graph1)
      self.assertEqual(1, len(var_dict1))

      saver_list1 = graph1.get_collection(ops_lib.GraphKeys.SAVERS)
      self.assertEqual(1, len(saver_list1))

      with self.session(graph=graph1) as sess:
        saver_list1[0].restore(sess, saver1_ckpt)
        self.assertEqual(1.0, self.evaluate(var_dict1["variable1:0"]))

    graph2 = ops_lib.Graph()
    with graph2.as_default():
      # Same check for scope "hidden2", copied into another fresh graph.
      var_dict2 = meta_graph.copy_scoped_meta_graph(
          from_scope="hidden2",
          to_scope="new_hidden2",
          from_graph=graph,
          to_graph=graph2)
      self.assertEqual(1, len(var_dict2))

      saver_list2 = graph2.get_collection(ops_lib.GraphKeys.SAVERS)
      self.assertEqual(1, len(saver_list2))

      with self.session(graph=graph2) as sess:
        saver_list2[0].restore(sess, saver2_ckpt)
        self.assertEqual(2.0, self.evaluate(var_dict2["variable2:0"]))
2975
2976
class _OwnsAVariableSimple(trackable_base.Trackable):
  """A Trackable object which can be saved using a tf.train.Saver."""

  def __init__(self):
    # Resource variable exposed to the Saver solely through
    # _gather_saveables_for_checkpoint below.
    self.non_dep_variable = variable_scope.get_variable(
        name="non_dep_variable", initializer=6., use_resource=True)

  def _gather_saveables_for_checkpoint(self):
    """Maps the standard value key to the variable for checkpointing."""
    return {trackable_base.VARIABLE_VALUE_KEY: self.non_dep_variable}

  # The Saver sorts by name before parsing, so we need a name property.
  @property
  def name(self):
    return self.non_dep_variable.name
2991
2992
class _MirroringSaveable(
    saver_module.BaseSaverBuilder.ResourceVariableSaveable):
  """A SaveableObject which saves one variable but restores into two.

  Only the primary variable is handed to the base class (and hence saved);
  restore() assigns the restored tensor to both the primary and the
  mirrored variable.
  """

  def __init__(self, primary_variable, mirrored_variable, name):
    self._primary_variable = primary_variable
    self._mirrored_variable = mirrored_variable
    super(_MirroringSaveable, self).__init__(
        self._primary_variable, "", name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into both variables."""
    # Exactly one tensor is expected, since only the primary value is saved.
    tensor, = restored_tensors
    return control_flow_ops.group(
        self._primary_variable.assign(tensor),
        self._mirrored_variable.assign(tensor))
3008
3009
class _OwnsMirroredVariables(trackable_base.Trackable):
  """A Trackable object which returns a more complex SaveableObject."""

  def __init__(self):
    # Primary variable: the one whose value goes into the checkpoint.
    self.non_dep_variable = variable_scope.get_variable(
        name="non_dep_variable", initializer=6., use_resource=True)
    # Mirror: overwritten with the primary's value on restore.
    self.mirrored = variable_scope.get_variable(
        name="mirrored", initializer=15., use_resource=True)

  def _gather_saveables_for_checkpoint(self):
    """Returns a factory producing a saveable that restores both variables."""
    # The default argument binds the variable's name at definition time so
    # the factory needs no arguments when invoked by the saving machinery.
    def _saveable_factory(name=self.non_dep_variable.name):
      return _MirroringSaveable(
          primary_variable=self.non_dep_variable,
          mirrored_variable=self.mirrored,
          name=name)
    return {trackable_base.VARIABLE_VALUE_KEY: _saveable_factory}

  # The Saver sorts by name before parsing, so we need a name property.
  @property
  def name(self):
    return self.non_dep_variable.name
3031
3032
class TrackableCompatibilityTests(test.TestCase):
  """Tests saving/restoring Trackable objects via the name-based Saver."""

  # TODO(allenl): Track down python3 reference cycles in these tests.
  @test_util.run_in_graph_and_eager_modes
  def testNotSaveableButIsTrackable(self):
    """A Trackable exposing a plain variable works in list and dict var_lists."""
    v = _OwnsAVariableSimple()
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    for saver in (saver_module.Saver(var_list=[v]),
                  saver_module.Saver(var_list={"v": v})):
      with self.cached_session() as sess:
        self.evaluate(v.non_dep_variable.assign(42.))
        save_path = saver.save(sess, prefix)
        self.evaluate(v.non_dep_variable.assign(43.))
        saver.restore(sess, save_path)
        # Restore must bring back the value saved before the overwrite.
        self.assertEqual(42., self.evaluate(v.non_dep_variable))

  @test_util.run_in_graph_and_eager_modes
  def testMoreComplexSaveableReturned(self):
    """A custom SaveableObject restores the saved value into both variables."""
    v = _OwnsMirroredVariables()
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    self.evaluate(v.non_dep_variable.assign(42.))
    for saver in (saver_module.Saver(var_list=[v]),
                  saver_module.Saver(var_list={"v": v})):
      with self.cached_session() as sess:
        save_path = saver.save(sess, prefix)
        self.evaluate(v.non_dep_variable.assign(43.))
        self.evaluate(v.mirrored.assign(44.))
        saver.restore(sess, save_path)
        # Both variables end up with the saved primary value (42.).
        self.assertEqual(42., self.evaluate(v.non_dep_variable))
        self.assertEqual(42., self.evaluate(v.mirrored))

  def testSingleTensorEvaluation(self):
    """A SaveSpec tensor callable is evaluated exactly once per save."""

    class _CountingSaveable(saver_module.BaseSaverBuilder.SaveableObject):
      """A saveable whose save tensor counts how often it is evaluated."""

      def __init__(self, name):
        self.eval_count = 0
        def _tensor():
          # Invoked whenever the SaveSpec's tensor is materialized.
          self.eval_count += 1
          return constant_op.constant([1.])
        dummy_op = constant_op.constant([2.])
        super(_CountingSaveable, self).__init__(
            dummy_op,
            [saver_module.BaseSaverBuilder.SaveSpec(
                _tensor, "", name, dtype=dummy_op.dtype,
                device=dummy_op.device)],
            name)

      def restore(self, restored_tensors, restored_shapes):
        """Restore the same value into both variables."""
        pass

    with context.eager_mode():
      v = _CountingSaveable("foo")
      saver = saver_module.Saver(var_list=[v])
      test_dir = self.get_temp_dir()
      prefix = os.path.join(test_dir, "ckpt")
      with self.cached_session() as sess:
        save_path = saver.save(sess, prefix)
        self.assertEqual(1, v.eval_count)
        saver.restore(sess, save_path)
        # Restoring must not re-evaluate the save tensor.
        self.assertEqual(1, v.eval_count)

  def testVariableNotFoundErrorRaised(self):
    # Restore does some tricky exception handling to figure out if it should
    # load an object-based checkpoint. Tests that the exception handling isn't
    # too broad.
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    a = resource_variable_ops.ResourceVariable(1., name="a")
    b = resource_variable_ops.ResourceVariable(1., name="b")
    a_saver = saver_module.Saver([a])
    b_saver = saver_module.Saver([b])
    with self.cached_session() as sess:
      self.evaluate(a.initializer)
      save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
      # The checkpoint only contains "a"; restoring "b" must fail cleanly.
      with self.assertRaisesRegex(errors.NotFoundError,
                                  "Key b not found in checkpoint"):
        b_saver.restore(sess=sess, save_path=save_path)

      with self.assertRaises(errors.NotFoundError) as cs:
        b_saver.restore(sess=sess, save_path=save_path)

      # Make sure we don't have a confusing "During handling of the above
      # exception" block in Python 3.
      self.assertNotIn("NewCheckpointReader", cs.exception.message)

  @test_util.run_v1_only("train.Saver is V1 only API.")
  def testGraphChangedForRestoreErrorRaised(self):
    """Restoring into a variable with a different shape raises a clear error."""
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    with ops_lib.Graph().as_default() as g:
      # Save a scalar variable named "a".
      a = variables.VariableV1(1., name="a")
      a_saver = saver_module.Saver([a])

      with self.session(graph=g) as sess:
        self.evaluate(a.initializer)
        save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)

    with ops_lib.Graph().as_default() as g:
      # Rebuild "a" with shape [1] (vs. scalar) and attempt to restore.
      a = variables.VariableV1([1.], name="a")
      a_saver = saver_module.Saver([a])
      with self.session(graph=g) as sess:
        with self.assertRaisesRegex(
            errors.InvalidArgumentError,
            "a mismatch between the current graph and the graph"):
          a_saver.restore(sess=sess, save_path=save_path)
3144
3145
# Standard TensorFlow test entry point.
if __name__ == "__main__":
  test.main()
3148