1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Tests for tensorflow.python.training.saver.py.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import glob 22import math 23import os 24import random 25import time 26 27import numpy as np 28import six 29 30from google.protobuf.any_pb2 import Any 31 32from tensorflow.core.protobuf import config_pb2 33from tensorflow.core.protobuf import meta_graph_pb2 34from tensorflow.core.protobuf import queue_runner_pb2 35from tensorflow.core.protobuf import rewriter_config_pb2 36from tensorflow.core.protobuf import saver_pb2 37from tensorflow.python.client import session 38from tensorflow.python.data.ops import dataset_ops 39from tensorflow.python.data.ops import iterator_ops 40from tensorflow.python.eager import context 41from tensorflow.python.framework import constant_op 42from tensorflow.python.framework import dtypes 43from tensorflow.python.framework import errors 44from tensorflow.python.framework import errors_impl 45from tensorflow.python.framework import function 46from tensorflow.python.framework import graph_io 47from tensorflow.python.framework import meta_graph 48from tensorflow.python.framework import ops as ops_lib 49from tensorflow.python.framework import test_util 50from tensorflow.python.lib.io import file_io 51from tensorflow.python.ops import array_ops 52from tensorflow.python.ops import control_flow_ops 53from tensorflow.python.ops import data_flow_ops 54from tensorflow.python.ops import gradients_impl 55from tensorflow.python.ops import math_ops 56from tensorflow.python.ops import nn_ops 57from tensorflow.python.ops import partitioned_variables 58from tensorflow.python.ops import random_ops 59from tensorflow.python.ops import resource_variable_ops 60from tensorflow.python.ops import sparse_ops 61from tensorflow.python.ops import variable_scope 62from tensorflow.python.ops import variables 63import tensorflow.python.ops.nn_grad # pylint: disable=unused-import 64from tensorflow.python.platform import gfile 65from tensorflow.python.platform import test 66from tensorflow.python.summary import summary 67from tensorflow.python.training import adam 68from tensorflow.python.training import checkpoint_management 69from tensorflow.python.training import gradient_descent 70from tensorflow.python.training import py_checkpoint_reader 71from tensorflow.python.training import queue_runner_impl 72from tensorflow.python.training import saver as saver_module 73from tensorflow.python.training import saver_test_utils 74from tensorflow.python.training.tracking import base as trackable_base 75from tensorflow.python.util import compat 76 77 78class SaverTest(test.TestCase): 79 80 def basicSaveRestore(self, variable_op): 81 save_path = os.path.join(self.get_temp_dir(), "basic_save_restore") 82 83 with self.session(graph=ops_lib.Graph()) as sess: 84 # Build a graph with 2 parameter nodes, and Save and 85 # Restore nodes for them. 86 v0 = variable_op(10.0, name="v0") 87 v1 = variable_op(20.0, name="v1") 88 v2 = saver_test_utils.CheckpointedOp(name="v2") 89 v2_init = v2.insert("k1", 30.0) 90 91 # Initialize all variables 92 if not context.executing_eagerly(): 93 self.evaluate([variables.global_variables_initializer(), v2_init]) 94 95 # Check that the parameter nodes have been initialized. 96 self.assertEqual(10.0, self.evaluate(v0)) 97 self.assertEqual(20.0, self.evaluate(v1)) 98 self.assertEqual(b"k1", self.evaluate(v2.keys())) 99 self.assertEqual(30.0, self.evaluate(v2.values())) 100 101 # Save the initialized values in the file at "save_path" 102 save = saver_module.Saver( 103 { 104 "v0": v0, 105 "v1": v1, 106 "v2": v2.saveable 107 }, restore_sequentially=True) 108 val = save.save(sess, save_path) 109 self.assertTrue(isinstance(val, six.string_types)) 110 self.assertEqual(save_path, val) 111 112 # Start a second session. In that session the parameter nodes 113 # have not been initialized either. 114 with self.session(graph=ops_lib.Graph()) as sess: 115 v0 = variable_op(-1.0, name="v0") 116 v1 = variable_op(-1.0, name="v1") 117 v2 = saver_test_utils.CheckpointedOp(name="v2") 118 119 # Assert that the variables are not initialized. 120 if not context.executing_eagerly(): 121 self.assertEqual( 122 len(variables.report_uninitialized_variables().eval()), 2) 123 self.assertEqual(0, len(self.evaluate(v2.keys()))) 124 self.assertEqual(0, len(self.evaluate(v2.values()))) 125 # Restore the saved values in the parameter nodes. 126 save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable}) 127 save.restore(sess, save_path) 128 # Check that the parameter nodes have been restored. 129 self.assertEqual(10.0, self.evaluate(v0)) 130 self.assertEqual(20.0, self.evaluate(v1)) 131 self.assertEqual(b"k1", self.evaluate(v2.keys())) 132 self.assertEqual(30.0, self.evaluate(v2.values())) 133 134 # Build another graph with 2 nodes, initialized 135 # differently, and a Restore node for them. 136 with self.session(graph=ops_lib.Graph()) as sess: 137 v0_2 = variable_op(1000.0, name="v0") 138 v1_2 = variable_op(2000.0, name="v1") 139 v2_2 = saver_test_utils.CheckpointedOp(name="v2") 140 v2_init = v2_2.insert("k1000", 3000.0) 141 142 # Check that the parameter nodes have been initialized. 143 if not context.executing_eagerly(): 144 init_all_op = [variables.global_variables_initializer(), v2_init] 145 self.evaluate(init_all_op) 146 # TODO(xpan): Why _mutable_hash_table_v2 doesn't create empty 147 # table as it claims in eager mode? 148 self.assertEqual(b"k1000", self.evaluate(v2_2.keys())) 149 self.assertEqual(3000.0, self.evaluate(v2_2.values())) 150 self.assertEqual(1000.0, self.evaluate(v0_2)) 151 self.assertEqual(2000.0, self.evaluate(v1_2)) 152 153 # Restore the values saved earlier in the parameter nodes. 154 save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable}) 155 save2.restore(sess, save_path) 156 # Check that the parameter nodes have been restored. 157 self.assertEqual(10.0, self.evaluate(v0_2)) 158 self.assertEqual(20.0, self.evaluate(v1_2)) 159 self.assertEqual(b"k1", self.evaluate(v2_2.keys())) 160 self.assertEqual(30.0, self.evaluate(v2_2.values())) 161 162 def testBasic(self): 163 self.basicSaveRestore(variables.Variable) 164 165 @test_util.run_in_graph_and_eager_modes 166 def testResourceBasic(self): 167 self.basicSaveRestore(resource_variable_ops.ResourceVariable) 168 169 def testResourceColocation(self): 170 # train.Saver is V1 only API. 171 with ops_lib.Graph().as_default(): 172 partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2) 173 with ops_lib.device("/job:ps/device:GPU:0"): 174 v = variable_scope.get_variable( 175 "v0", shape=[10, 2], partitioner=partitioner, use_resource=True) 176 saver_module.Saver({"v0": v}).build() 177 save_op = None 178 for op in ops_lib.get_default_graph().get_operations(): 179 if op.type == "SaveV2": 180 save_op = op 181 break 182 assert save_op is not None 183 for save_inp in save_op.inputs[3:]: 184 # Input to SaveV2 op is placed on CPU of the same device as 185 # the Variable. 186 self.assertEqual("/job:ps/device:CPU:0", save_inp.device) 187 188 def testResourceVariableReadOpsAddedDeterministically(self): 189 graph_defs = [] 190 num_graphs = 10 191 for _ in range(num_graphs): 192 with ops_lib.Graph().as_default() as g: 193 for i in range(20): 194 resource_variable_ops.ResourceVariable(i, name="var%s" % i) 195 saver_module.Saver() 196 graph_defs.append(g.as_graph_def()) 197 for i in range(num_graphs - 1): 198 self.assertEqual(graph_defs[i], graph_defs[i + 1]) 199 200 def testEagerBasic(self): 201 with context.eager_mode(): 202 ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt") 203 204 v1 = resource_variable_ops.ResourceVariable(3.14, name="v1") 205 v2 = resource_variable_ops.ResourceVariable([1, 2], name="v2") 206 save = saver_module.Saver([v1, v2]) 207 save.save(None, ckpt_prefix) 208 209 v1.assign(0.0) 210 v2.assign([0, 0]) 211 self.assertNear(0.0, self.evaluate(v1), 1e-5) 212 self.assertAllEqual([0, 0], self.evaluate(v2)) 213 214 save.restore(None, ckpt_prefix) 215 self.assertNear(3.14, self.evaluate(v1), 1e-5) 216 self.assertAllEqual([1, 2], self.evaluate(v2)) 217 218 def testEagerGraphCompatibility(self): 219 # Save from graph mode and restore from eager mode. 220 graph_ckpt_prefix = os.path.join(self.get_temp_dir(), "graph_ckpt") 221 with context.graph_mode(): 222 with self.session(graph=ops_lib.Graph()) as sess: 223 # Create a graph model and save the checkpoint. 224 w1 = resource_variable_ops.ResourceVariable(1.0, name="w1") 225 w2 = resource_variable_ops.ResourceVariable(2.0, name="w2") 226 graph_saver = saver_module.Saver([w1, w2]) 227 self.evaluate(variables.global_variables_initializer()) 228 graph_saver.save(sess, graph_ckpt_prefix) 229 230 with context.eager_mode(): 231 ops_lib._default_graph_stack.reset() # pylint: disable=protected-access 232 ops_lib.reset_default_graph() 233 234 w1 = resource_variable_ops.ResourceVariable(0.0, name="w1") 235 w2 = resource_variable_ops.ResourceVariable(0.0, name="w2") 236 237 graph_saver = saver_module.Saver([w1, w2]) 238 graph_saver.restore(None, graph_ckpt_prefix) 239 240 self.assertAllEqual(self.evaluate(w1), 1.0) 241 self.assertAllEqual(self.evaluate(w2), 2.0) 242 243 # Save from eager mode and restore from graph mode. 244 eager_ckpt_prefix = os.path.join(self.get_temp_dir(), "eager_ckpt") 245 with context.eager_mode(): 246 ops_lib._default_graph_stack.reset() # pylint: disable=protected-access 247 ops_lib.reset_default_graph() 248 249 w3 = resource_variable_ops.ResourceVariable(3.0, name="w3") 250 w4 = resource_variable_ops.ResourceVariable(4.0, name="w4") 251 252 graph_saver = saver_module.Saver([w3, w4]) 253 graph_saver.save(None, eager_ckpt_prefix) 254 255 with context.graph_mode(): 256 with self.session(graph=ops_lib.Graph()) as sess: 257 w3 = resource_variable_ops.ResourceVariable(0.0, name="w3") 258 w4 = resource_variable_ops.ResourceVariable(0.0, name="w4") 259 graph_saver = saver_module.Saver([w3, w4]) 260 self.evaluate(variables.global_variables_initializer()) 261 graph_saver.restore(sess, eager_ckpt_prefix) 262 self.assertAllEqual(w3, 3.0) 263 self.assertAllEqual(w4, 4.0) 264 265 @test_util.run_in_graph_and_eager_modes 266 def testResourceSaveRestoreCachingDevice(self): 267 save_path = os.path.join(self.get_temp_dir(), "resource_cache") 268 with self.session(graph=ops_lib.Graph()) as sess: 269 v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0", 270 name="v") 271 if context.executing_eagerly(): 272 sess = None 273 else: 274 self.evaluate(variables.global_variables_initializer()) 275 save = saver_module.Saver([v]) 276 save.save(sess, save_path) 277 278 save2 = saver_module.Saver([v]) 279 save2.restore(sess, save_path) 280 self.assertEqual(self.evaluate(v), [1]) 281 282 def testNoAdditionalOpsAddedBySaverForResourceVariablesOutsideSaveScope(self): 283 with ops_lib.Graph().as_default() as g: 284 v = resource_variable_ops.ResourceVariable(1.0, name="v") 285 with ops_lib.name_scope("saver1"): 286 saver_module.Saver() 287 with ops_lib.name_scope("saver2"): 288 saver_module.Saver({"name": v}) 289 ops_in_saver1_scope_but_not_save_scope = [ 290 op for op in g.get_operations() 291 if (op.name.startswith("saver1/") and 292 not op.name.startswith("saver1/save/"))] 293 self.assertEqual(ops_in_saver1_scope_but_not_save_scope, []) 294 ops_in_saver2_scope_but_not_save_scope = [ 295 op for op in g.get_operations() 296 if (op.name.startswith("saver2/") and 297 not op.name.startswith("saver2/save/"))] 298 self.assertEqual(ops_in_saver2_scope_but_not_save_scope, []) 299 300 def testSaveCopyRestoreWithSaveRelativePaths(self): 301 """Save, copy checkpoint dir and restore from copied dir. 302 303 This only works for save_relative_paths=True. 304 """ 305 save_dir1 = os.path.join(self.get_temp_dir(), "save_dir1") 306 os.mkdir(save_dir1) 307 save_path1 = os.path.join(save_dir1, "save_copy_restore") 308 309 # train.Saver is V1 only API. 310 with ops_lib.Graph().as_default(): 311 # Build a graph with 2 parameter nodes, and Save and 312 # Restore nodes for them. 313 v0 = variables.VariableV1(10.0, name="v0") 314 v1 = variables.VariableV1(20.0, name="v1") 315 v2 = saver_test_utils.CheckpointedOp(name="v2") 316 v2_init = v2.insert("k1", 30.0) 317 save = saver_module.Saver( 318 var_list={ 319 "v0": v0, 320 "v1": v1, 321 "v2": v2.saveable 322 }, 323 restore_sequentially=True, 324 save_relative_paths=True) 325 init_all_op = [variables.global_variables_initializer(), v2_init] 326 327 with self.cached_session() as sess: 328 # Initialize all variables 329 self.evaluate(init_all_op) 330 331 # Check that the parameter nodes have been initialized. 332 self.assertEqual(10.0, self.evaluate(v0)) 333 self.assertEqual(20.0, self.evaluate(v1)) 334 self.assertEqual(b"k1", self.evaluate(v2.keys())) 335 self.assertEqual(30.0, self.evaluate(v2.values())) 336 337 # Save the initialized values in the file at "save_path" 338 val = save.save(sess, save_path1) 339 self.assertTrue(isinstance(val, six.string_types)) 340 self.assertEqual(save_path1, val) 341 342 self.assertEqual( 343 checkpoint_management.latest_checkpoint(save_dir1), save_path1) 344 save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2") 345 os.renames(save_dir1, save_dir2) 346 save_path2 = os.path.join(save_dir2, "save_copy_restore") 347 self.assertEqual( 348 checkpoint_management.latest_checkpoint(save_dir2), save_path2) 349 350 # Start a second session. In that session the parameter nodes 351 # have not been initialized either. 352 with self.cached_session() as sess: 353 v0 = variables.VariableV1(-1.0, name="v0") 354 v1 = variables.VariableV1(-1.0, name="v1") 355 v2 = saver_test_utils.CheckpointedOp(name="v2") 356 save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable}) 357 358 # Assert that the variables are not initialized. 359 self.assertEqual( 360 len(variables.report_uninitialized_variables().eval()), 2) 361 self.assertEqual(0, len(self.evaluate(v2.keys()))) 362 self.assertEqual(0, len(self.evaluate(v2.values()))) 363 364 # Restore the saved values in the parameter nodes. 365 save.restore(sess, save_path2) 366 # Check that the parameter nodes have been restored. 367 self.assertEqual(10.0, self.evaluate(v0)) 368 self.assertEqual(20.0, self.evaluate(v1)) 369 self.assertEqual(b"k1", self.evaluate(v2.keys())) 370 self.assertEqual(30.0, self.evaluate(v2.values())) 371 372 def testFilenameTensor(self): 373 # train.Saver is V1 only API. 374 with ops_lib.Graph().as_default(): 375 v0 = variables.VariableV1(0, name="v0") 376 filename = b"somerandomfilename" 377 save = saver_module.Saver({"v0": v0}, filename=filename) 378 with self.cached_session() as sess: 379 tensor = sess.graph.get_tensor_by_name( 380 save.saver_def.filename_tensor_name) 381 self.assertEqual(self.evaluate(tensor), filename) 382 383 def testInvalidPath(self): 384 v0 = variables.VariableV1(0, name="v0") 385 for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2): 386 with self.cached_session() as sess: 387 save = saver_module.Saver({"v0": v0}, write_version=ver) 388 with self.assertRaisesRegex( 389 ValueError, "The passed save_path is not a valid checkpoint:"): 390 save.restore(sess, "invalid path") 391 392 @test_util.run_v1_only("train.Saver is V1 only API.") 393 def testInt64(self): 394 save_path = os.path.join(self.get_temp_dir(), "int64") 395 396 with self.cached_session() as sess: 397 # Build a graph with 1 node, and save and restore for them. 398 v = variables.VariableV1(np.int64(15), name="v") 399 save = saver_module.Saver({"v": v}, restore_sequentially=True) 400 self.evaluate(variables.global_variables_initializer()) 401 402 # Save the initialized values in the file at "save_path" 403 val = save.save(sess, save_path) 404 self.assertTrue(isinstance(val, six.string_types)) 405 self.assertEqual(save_path, val) 406 407 with self.cached_session() as sess: 408 v = variables.VariableV1(np.int64(-1), name="v") 409 save = saver_module.Saver({"v": v}) 410 411 with self.assertRaisesWithPredicateMatch( 412 errors_impl.OpError, lambda e: "uninitialized value v" in e.message): 413 self.evaluate(v) 414 415 # Restore the saved values in the parameter nodes. 416 save.restore(sess, save_path) 417 # Check that the parameter nodes have been restored. 418 self.assertEqual(np.int64(15), self.evaluate(v)) 419 420 def testSomeErrors(self): 421 with ops_lib.Graph().as_default(): 422 v0 = variables.VariableV1([10.0], name="v0") 423 v1 = variables.VariableV1([20.0], name="v1") 424 v2 = variables.VariableV1([20.0], name="v2") 425 v2._set_save_slice_info( 426 variables.Variable.SaveSliceInfo("v1", [1], [0], [1])) 427 428 # By default the name used for "v2" will be "v1" and raise an error. 429 with self.assertRaisesRegex(ValueError, "same name: v1"): 430 saver_module.Saver([v0, v1, v2]) 431 432 # The names are different and will work. 433 saver_module.Saver({"vee1": v1, "other": [v2]}) 434 435 # Partitioned variables also cause name conflicts. 436 p_v1 = variable_scope.get_variable( 437 "p_v1", 438 shape=[4, 5], 439 partitioner=partitioned_variables.fixed_size_partitioner( 440 num_shards=2)) 441 p_v2 = variable_scope.get_variable( 442 "p_v2", 443 shape=[4, 5], 444 partitioner=partitioned_variables.fixed_size_partitioner( 445 num_shards=2)) 446 p_v2._name = "p_v1" 447 with self.assertRaisesRegex(ValueError, "same name: p_v1"): 448 saver_module.Saver([p_v1, p_v2]) 449 450 def testSameName(self): 451 with ops_lib.Graph().as_default(): 452 v0 = variables.VariableV1([10.0], name="v0") 453 v2 = saver_test_utils.CheckpointedOp(name="v2") 454 455 # Saving one variable under two names raises an error. 456 with self.assertRaisesRegex( 457 ValueError, "The same saveable will be restored with two names: v0"): 458 saver_module.Saver({"v0": v0, "v0too": v0}) 459 460 # Ditto for custom saveables. 461 with self.assertRaisesRegex( 462 ValueError, "The same saveable will be restored with two names: v2"): 463 saver_module.Saver({"v2": v2.saveable, "v2too": v2.saveable}) 464 465 # Verify non-duplicate names work. 466 saver_module.Saver({"v0": v0, "v2": v2.saveable}) 467 468 @test_util.run_v1_only("train.Saver and VariableV1 are V1 only APIs.") 469 def testBasicsWithListOfVariables(self): 470 save_path = os.path.join(self.get_temp_dir(), "basics_with_list") 471 472 with self.session(graph=ops_lib.Graph()) as sess: 473 # Build a graph with 2 parameter nodes, and Save and 474 # Restore nodes for them. 475 v0 = variables.VariableV1(10.0, name="v0") 476 v1 = variables.VariableV1(20.0, name="v1") 477 v2 = saver_test_utils.CheckpointedOp(name="v2") 478 v2_init = v2.insert("k1", 30.0) 479 save = saver_module.Saver([v0, v1, v2.saveable]) 480 self.evaluate(variables.global_variables_initializer()) 481 v2_init.run() 482 483 # Check that the parameter nodes have been initialized. 484 self.assertEqual(10.0, self.evaluate(v0)) 485 self.assertEqual(20.0, self.evaluate(v1)) 486 self.assertEqual(b"k1", self.evaluate(v2.keys())) 487 self.assertEqual(30.0, self.evaluate(v2.values())) 488 489 # Save the initialized values in the file at "save_path" 490 val = save.save(sess, save_path) 491 self.assertTrue(isinstance(val, six.string_types)) 492 self.assertEqual(save_path, val) 493 494 # Start a second session. In that session the variables 495 # have not been initialized either. 496 with self.session(graph=ops_lib.Graph()) as sess: 497 v0 = variables.VariableV1(-1.0, name="v0") 498 v1 = variables.VariableV1(-1.0, name="v1") 499 v2 = saver_test_utils.CheckpointedOp(name="v2") 500 save = saver_module.Saver([v0, v1, v2.saveable]) 501 502 with self.assertRaisesWithPredicateMatch( 503 errors_impl.OpError, lambda e: "uninitialized value v0" in e.message): 504 self.evaluate(v0) 505 with self.assertRaisesWithPredicateMatch( 506 errors_impl.OpError, lambda e: "uninitialized value v1" in e.message): 507 self.evaluate(v1) 508 self.assertEqual(0, len(self.evaluate(v2.keys()))) 509 self.assertEqual(0, len(self.evaluate(v2.values()))) 510 511 # Restore the saved values in the parameter nodes. 512 save.restore(sess, save_path) 513 # Check that the parameter nodes have been restored. 514 self.assertEqual(10.0, self.evaluate(v0)) 515 self.assertEqual(20.0, self.evaluate(v1)) 516 self.assertEqual(b"k1", self.evaluate(v2.keys())) 517 self.assertEqual(30.0, self.evaluate(v2.values())) 518 519 # Build another graph with 2 nodes, initialized 520 # differently, and a Restore node for them. 521 with self.session(graph=ops_lib.Graph()) as sess: 522 v0_2 = variables.VariableV1(1000.0, name="v0") 523 v1_2 = variables.VariableV1(2000.0, name="v1") 524 v2_2 = saver_test_utils.CheckpointedOp(name="v2") 525 save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable]) 526 v2_2.insert("k1000", 3000.0).run() 527 self.evaluate(variables.global_variables_initializer()) 528 529 # Check that the parameter nodes have been initialized. 530 self.assertEqual(1000.0, self.evaluate(v0_2)) 531 self.assertEqual(2000.0, self.evaluate(v1_2)) 532 self.assertEqual(b"k1000", self.evaluate(v2_2.keys())) 533 self.assertEqual(3000.0, self.evaluate(v2_2.values())) 534 # Restore the values saved earlier in the parameter nodes. 535 save2.restore(sess, save_path) 536 # Check that the parameter nodes have been restored. 537 self.assertEqual(10.0, self.evaluate(v0_2)) 538 self.assertEqual(20.0, self.evaluate(v1_2)) 539 self.assertEqual(b"k1", self.evaluate(v2_2.keys())) 540 self.assertEqual(30.0, self.evaluate(v2_2.values())) 541 542 def _SaveAndLoad(self, var_name, var_value, other_value, save_path): 543 with self.session(graph=ops_lib.Graph()) as sess: 544 var = resource_variable_ops.ResourceVariable(var_value, name=var_name) 545 save = saver_module.Saver({var_name: var}) 546 if not context.executing_eagerly(): 547 self.evaluate(var.initializer) 548 val = save.save(sess, save_path) 549 self.assertEqual(save_path, val) 550 with self.session(graph=ops_lib.Graph()) as sess: 551 var = resource_variable_ops.ResourceVariable(other_value, name=var_name) 552 save = saver_module.Saver({var_name: var}) 553 save.restore(sess, save_path) 554 self.assertAllClose(var_value, self.evaluate(var)) 555 556 def testCacheRereadsFile(self): 557 save_path = os.path.join(self.get_temp_dir(), "cache_rereads") 558 # Save and reload one Variable named "var0". 559 self._SaveAndLoad("var0", 0.0, 1.0, save_path) 560 # Save and reload one Variable named "var1" in the same file. 561 # The cached readers should know to re-read the file. 562 self._SaveAndLoad("var1", 1.1, 2.2, save_path) 563 564 def testAllowEmpty(self): 565 save_path = os.path.join(self.get_temp_dir(), "allow_empty") 566 # train.Saver is V1 only API. 567 with ops_lib.Graph().as_default(), self.cached_session() as sess: 568 _ = constant_op.constant(1) 569 save = saver_module.Saver(allow_empty=True) 570 val = save.save(sess, save_path) 571 self.assertIsNone(val) 572 with ops_lib.Graph().as_default(), self.cached_session() as sess: 573 save = saver_module.Saver(allow_empty=True) 574 save.restore(sess, save_path) 575 576 def testGPU(self): 577 if not test.is_gpu_available(): 578 return 579 save_path = os.path.join(self.get_temp_dir(), "gpu") 580 with session.Session("", graph=ops_lib.Graph()) as sess: 581 with sess.graph.device(test.gpu_device_name()): 582 v0_1 = variables.VariableV1(123.45) 583 save = saver_module.Saver({"v0": v0_1}) 584 self.evaluate(variables.global_variables_initializer()) 585 save.save(sess, save_path) 586 587 with session.Session("", graph=ops_lib.Graph()) as sess: 588 with sess.graph.device(test.gpu_device_name()): 589 v0_2 = variables.VariableV1(543.21) 590 save = saver_module.Saver({"v0": v0_2}) 591 self.evaluate(variables.global_variables_initializer()) 592 593 def testSharedServerOnGPU(self): 594 if not test.is_gpu_available(): 595 return 596 save_path = os.path.join(self.get_temp_dir(), "gpu") 597 with session.Session("", graph=ops_lib.Graph()) as sess: 598 with sess.graph.device(test.gpu_device_name()): 599 v0_1 = variables.VariableV1(123.45) 600 save = saver_module.Saver({"v0": v0_1}, sharded=True, allow_empty=True) 601 self.evaluate(variables.global_variables_initializer()) 602 save.save(sess, save_path) 603 604 with session.Session("", graph=ops_lib.Graph()) as sess: 605 with sess.graph.device(test.gpu_device_name()): 606 v0_2 = variables.VariableV1(543.21) 607 save = saver_module.Saver({"v0": v0_2}, sharded=True, allow_empty=True) 608 self.evaluate(variables.global_variables_initializer()) 609 610 def testVariables(self): 611 save_path = os.path.join(self.get_temp_dir(), "variables") 612 with session.Session("", graph=ops_lib.Graph()) as sess: 613 one = variables.VariableV1(1.0) 614 twos = variables.VariableV1([2.0, 2.0, 2.0]) 615 v2 = saver_test_utils.CheckpointedOp(name="v2") 616 init = variables.global_variables_initializer() 617 save = saver_module.Saver() 618 init.run() 619 v2.insert("k1", 3.0).run() 620 save.save(sess, save_path) 621 622 with session.Session("", graph=ops_lib.Graph()) as sess: 623 one = variables.VariableV1(0.0) 624 twos = variables.VariableV1([0.0, 0.0, 0.0]) 625 v2 = saver_test_utils.CheckpointedOp(name="v2") 626 # Saver with no arg, defaults to 'all variables'. 627 save = saver_module.Saver() 628 save.restore(sess, save_path) 629 self.assertAllClose(1.0, self.evaluate(one)) 630 self.assertAllClose([2.0, 2.0, 2.0], self.evaluate(twos)) 631 self.assertEqual(b"k1", self.evaluate(v2.keys())) 632 self.assertEqual(3.0, self.evaluate(v2.values())) 633 634 def testVarListShouldBeEmptyInDeferredBuild(self): 635 with ops_lib.Graph().as_default(): 636 v = variables.VariableV1(1.0) 637 with self.assertRaisesRegex(ValueError, "defer_build"): 638 saver_module.Saver([v], defer_build=True) 639 640 def testBuildShouldBeCalledBeforeSaveInCaseOfDeferBuild(self): 641 save_path = os.path.join(self.get_temp_dir(), "error_deferred_build") 642 with ops_lib.Graph().as_default(), session.Session() as sess: 643 variables.VariableV1(1.0) 644 saver = saver_module.Saver(defer_build=True) 645 with self.assertRaisesRegex(RuntimeError, "build"): 646 saver.save(sess, save_path) 647 648 def testDeferredBuild(self): 649 save_path = os.path.join(self.get_temp_dir(), "deferred_build") 650 with session.Session("", graph=ops_lib.Graph()) as sess: 651 one = variables.VariableV1(1.0) 652 save = saver_module.Saver(defer_build=True) 653 # if build is not deferred, saver cannot save the `twos`. 654 twos = variables.VariableV1([2.0, 2.0, 2.0]) 655 init = variables.global_variables_initializer() 656 save.build() 657 init.run() 658 save.save(sess, save_path) 659 660 with session.Session("", graph=ops_lib.Graph()) as sess: 661 one = variables.VariableV1(0.0) 662 twos = variables.VariableV1([0.0, 0.0, 0.0]) 663 # Saver with no arg, defaults to 'all variables'. 664 save = saver_module.Saver() 665 save.restore(sess, save_path) 666 self.assertAllClose(1.0, self.evaluate(one)) 667 self.assertAllClose([2.0, 2.0, 2.0], self.evaluate(twos)) 668 669 @test_util.run_v1_only("train.Saver is V1 only API.") 670 def testReshape(self): 671 save_path = os.path.join(self.get_temp_dir(), "variables_reshape") 672 with session.Session("", graph=ops_lib.Graph()) as sess: 673 var = variables.VariableV1([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 674 init = variables.global_variables_initializer() 675 save = saver_module.Saver() 676 init.run() 677 save.save(sess, save_path) 678 679 # Error when restoring with default reshape=False 680 with session.Session("", graph=ops_lib.Graph()) as sess: 681 var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) 682 save = saver_module.Saver() 683 with self.assertRaisesRegex( 684 errors_impl.InvalidArgumentError, 685 "Assign requires shapes of both tensors to match."): 686 save.restore(sess, save_path) 687 688 # Restored to new shape with reshape=True 689 with session.Session("", graph=ops_lib.Graph()) as sess: 690 var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) 691 save = saver_module.Saver(reshape=True) 692 save.restore(sess, save_path) 693 self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], 694 self.evaluate(var)) 695 696 @test_util.run_in_graph_and_eager_modes 697 def testSaveWithGlobalStep(self, pad_step_number=False): 698 save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step") 699 global_step_int = 5 700 # Save and reload one Variable named "var0". 701 self._SaveAndLoad("var0", 0.0, 1.0, save_path) 702 for use_tensor in [True, False]: 703 with self.session(graph=ops_lib.Graph()): 704 var = resource_variable_ops.ResourceVariable(1.0, name="var0") 705 save = saver_module.Saver( 706 { 707 var._shared_name: var 708 }, pad_step_number=pad_step_number) 709 if context.executing_eagerly(): 710 sess = None 711 else: 712 self.evaluate(var.initializer) 713 sess = ops_lib.get_default_session() 714 if use_tensor: 715 global_step = constant_op.constant(global_step_int) 716 val = save.save(sess, save_path, global_step=global_step) 717 else: 718 val = save.save(sess, save_path, global_step=global_step_int) 719 if pad_step_number: 720 expected_save_path = "%s-%s" % (save_path, 721 "{:08d}".format(global_step_int)) 722 else: 723 expected_save_path = "%s-%d" % (save_path, global_step_int) 724 self.assertEqual(expected_save_path, val) 725 726 def testSaveWithGlobalStepWithPadding(self): 727 self.testSaveWithGlobalStep(pad_step_number=True) 728 729 def testSaveToNonexistingPath(self): 730 file_io.write_string_to_file( 731 os.path.join(self.get_temp_dir(), "actually_a_file"), "") 732 paths = [ 733 os.path.join(self.get_temp_dir(), "nonexisting_dir/path"), 734 os.path.join(self.get_temp_dir(), "other_nonexisting_dir/path1/path2"), 735 os.path.join(self.get_temp_dir(), "actually_a_file/path"), 736 ] 737 738 for save_path in paths: 739 # Build a graph with 2 parameter nodes, and Save and 740 # Restore nodes for them. 741 v0 = variables.VariableV1(10.0, name="v0") 742 v1 = variables.VariableV1(20.0, name="v1") 743 save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True) 744 init_all_op = variables.global_variables_initializer() 745 746 # In the case where the parent directory doesn't exist, whether or not the 747 # save succeeds or fails is implementation dependent. Therefore we allow 748 # both cases. 749 try: 750 with self.cached_session() as sess: 751 # Initialize all variables 752 self.evaluate(init_all_op) 753 754 # Check that the parameter nodes have been initialized. 755 self.assertEqual(10.0, self.evaluate(v0)) 756 self.assertEqual(20.0, self.evaluate(v1)) 757 758 # Save the graph. 759 save.save(sess, save_path) 760 761 with self.cached_session() as sess: 762 # Restore the saved values in the parameter nodes. 763 save.restore(sess, save_path) 764 # Check that the parameter nodes have been restored. 765 self.assertEqual(10.0, self.evaluate(v0)) 766 self.assertEqual(20.0, self.evaluate(v1)) 767 except ValueError as exc: 768 error_msg_template = "Parent directory of {} doesn't exist, can't save." 769 self.assertEqual(error_msg_template.format(save_path), str(exc)) 770 771 def testSaveToURI(self): 772 # ParseURI functions don't work on Windows yet. 773 # TODO(jhseu): Remove this check when it works. 774 if os.name == "nt": 775 self.skipTest("Local URI support doesn't work on Windows") 776 save_path = "file://" + os.path.join(self.get_temp_dir(), "uri") 777 778 # Build a graph with 2 parameter nodes, and Save and 779 # Restore nodes for them. 780 v0 = variables.VariableV1(10.0, name="v0") 781 v1 = variables.VariableV1(20.0, name="v1") 782 save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True) 783 init_all_op = variables.global_variables_initializer() 784 785 with self.cached_session() as sess: 786 # Initialize all variables 787 self.evaluate(init_all_op) 788 789 # Check that the parameter nodes have been initialized. 790 self.assertEqual(10.0, self.evaluate(v0)) 791 self.assertEqual(20.0, self.evaluate(v1)) 792 save.save(sess, save_path) 793 794 def testSaveRestoreAndValidateVariableDtype(self): 795 for variable_op in [ 796 variables.Variable, resource_variable_ops.ResourceVariable 797 ]: 798 save_path = os.path.join(self.get_temp_dir(), "basic_save_restore") 799 800 # Build the first session. 801 with self.session(graph=ops_lib.Graph()) as sess: 802 v0 = variable_op(10.0, name="v0", dtype=dtypes.float32) 803 804 if not context.executing_eagerly(): 805 self.evaluate([variables.global_variables_initializer()]) 806 807 save = saver_module.Saver({"v0": v0}) 808 save.save(sess, save_path) 809 810 # Start a second session. 811 with self.session(graph=ops_lib.Graph()) as sess: 812 v0_wrong_dtype = variable_op(1, name="v0", dtype=dtypes.int32) 813 # Restore the saved value with different dtype 814 # in the parameter nodes. 815 save = saver_module.Saver({"v0": v0_wrong_dtype}) 816 with self.assertRaisesRegex(errors.InvalidArgumentError, 817 "original dtype"): 818 save.restore(sess, save_path) 819 820 # Test restoring large tensors (triggers a thread pool) 821 def testRestoreLargeTensors(self): 822 save_dir = self.get_temp_dir() 823 def _model(): 824 small_v = [variable_scope.get_variable( 825 "small%d" % i, shape=[10, 2], use_resource=True) for i in range(5)] 826 large_v = [variable_scope.get_variable( 827 "large%d" % i, shape=[32000, 1000], use_resource=True) 828 for i in range(3)] 829 return small_v + large_v 830 831 save_graph = ops_lib.Graph() 832 with save_graph.as_default(), self.session(graph=save_graph) as sess: 833 orig_vars = _model() 834 self.evaluate(variables.global_variables_initializer()) 835 save = saver_module.Saver(max_to_keep=1) 836 self.evaluate(variables.global_variables_initializer()) 837 save.save(sess, save_dir) 838 orig_vals = self.evaluate(orig_vars) 839 840 restore_graph = ops_lib.Graph() 841 with restore_graph.as_default(), self.session( 842 graph=restore_graph) as sess: 843 restored_vars = _model() 844 save = saver_module.Saver(max_to_keep=1) 845 save.restore(sess, save_dir) 846 restored_vals = self.evaluate(restored_vars) 847 848 for orig, restored in zip(orig_vals, restored_vals): 849 self.assertAllEqual(orig, restored) 850 851 852class SaveRestoreShardedTest(test.TestCase): 853 854 _WRITE_VERSION = saver_pb2.SaverDef.V1 855 856 def _get_test_dir(self, dirname): 857 test_dir = os.path.join(self.get_temp_dir(), dirname) 858 gfile.MakeDirs(test_dir) 859 return test_dir 860 861 def testBasics(self): 862 save_path = os.path.join(self.get_temp_dir(), "sharded_basics") 863 864 # Build a graph with 2 parameter nodes on different devices. 865 with session.Session( 866 target="", 867 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 868 with sess.graph.device("/cpu:0"): 869 v0 = variables.VariableV1(10, name="v0") 870 t0 = saver_test_utils.CheckpointedOp(name="t0") 871 with sess.graph.device("/cpu:1"): 872 v1 = variables.VariableV1(20, name="v1") 873 t1 = saver_test_utils.CheckpointedOp(name="t1") 874 save = saver_module.Saver( 875 { 876 "v0": v0, 877 "v1": v1, 878 "t0": t0.saveable, 879 "t1": t1.saveable 880 }, 881 write_version=self._WRITE_VERSION, 882 sharded=True) 883 self.evaluate(variables.global_variables_initializer()) 884 t0.insert("k1", 30.0).run() 885 t1.insert("k2", 40.0).run() 886 val = save.save(sess, save_path) 887 if save._write_version is saver_pb2.SaverDef.V1: 888 self.assertEqual(save_path + "-?????-of-00002", val) 889 else: 890 self.assertEqual(save_path, val) 891 meta_graph_filename = checkpoint_management.meta_graph_filename(val) 892 self.assertEqual(save_path + ".meta", meta_graph_filename) 893 894 if save._write_version is saver_pb2.SaverDef.V1: 895 # Restore different ops from shard 0 of the saved files. 896 with session.Session( 897 target="", 898 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 899 with sess.graph.device("/cpu:0"): 900 v0 = variables.VariableV1(111, name="v0") 901 t0 = saver_test_utils.CheckpointedOp(name="t0") 902 save = saver_module.Saver( 903 { 904 "v0": v0, 905 "t0": t0.saveable 906 }, 907 write_version=self._WRITE_VERSION, 908 sharded=True) 909 self.evaluate(variables.global_variables_initializer()) 910 t0.insert("k11", 33.0).run() 911 self.assertEqual(111, self.evaluate(v0)) 912 self.assertEqual(b"k11", self.evaluate(t0.keys())) 913 self.assertEqual(33.0, self.evaluate(t0.values())) 914 save.restore(sess, save_path + "-00000-of-00002") 915 self.assertEqual(10, self.evaluate(v0)) 916 self.assertEqual(b"k1", self.evaluate(t0.keys())) 917 self.assertEqual(30.0, self.evaluate(t0.values())) 918 919 # Restore different ops from shard 1 of the saved files. 920 with session.Session( 921 target="", 922 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 923 with sess.graph.device("/cpu:0"): 924 v1 = variables.VariableV1(222) 925 t1 = saver_test_utils.CheckpointedOp(name="t1") 926 save = saver_module.Saver( 927 { 928 "v1": v1, 929 "t1": t1.saveable 930 }, 931 write_version=self._WRITE_VERSION, 932 sharded=True) 933 self.evaluate(variables.global_variables_initializer()) 934 t1.insert("k22", 44.0).run() 935 self.assertEqual(222, self.evaluate(v1)) 936 self.assertEqual(b"k22", self.evaluate(t1.keys())) 937 self.assertEqual(44.0, self.evaluate(t1.values())) 938 save.restore(sess, save_path + "-00001-of-00002") 939 self.assertEqual(20, self.evaluate(v1)) 940 self.assertEqual(b"k2", self.evaluate(t1.keys())) 941 self.assertEqual(40.0, self.evaluate(t1.values())) 942 943 # Now try a restore with the sharded filename. 944 with session.Session( 945 target="", 946 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 947 with sess.graph.device("/cpu:0"): 948 v0 = variables.VariableV1(111, name="v0") 949 t0 = saver_test_utils.CheckpointedOp(name="t0") 950 with sess.graph.device("/cpu:1"): 951 v1 = variables.VariableV1(222, name="v1") 952 t1 = saver_test_utils.CheckpointedOp(name="t1") 953 save = saver_module.Saver( 954 { 955 "v0": v0, 956 "v1": v1, 957 "t0": t0.saveable, 958 "t1": t1.saveable 959 }, 960 write_version=self._WRITE_VERSION, 961 sharded=True) 962 self.evaluate(variables.global_variables_initializer()) 963 t0.insert("k11", 33.0).run() 964 t1.insert("k22", 44.0).run() 965 self.assertEqual(111, self.evaluate(v0)) 966 self.assertEqual(222, self.evaluate(v1)) 967 self.assertEqual(b"k11", self.evaluate(t0.keys())) 968 self.assertEqual(33.0, self.evaluate(t0.values())) 969 self.assertEqual(b"k22", self.evaluate(t1.keys())) 970 self.assertEqual(44.0, self.evaluate(t1.values())) 971 save_path = os.path.join(self.get_temp_dir(), "sharded_basics") 972 if save._write_version is saver_pb2.SaverDef.V1: 973 save.restore(sess, save_path + "-?????-of-?????") 974 else: 975 save.restore(sess, save_path) 976 self.assertEqual(10, self.evaluate(v0)) 977 self.assertEqual(20, self.evaluate(v1)) 978 self.assertEqual(b"k1", self.evaluate(t0.keys())) 979 self.assertEqual(30.0, self.evaluate(t0.values())) 980 self.assertEqual(b"k2", self.evaluate(t1.keys())) 981 self.assertEqual(40.0, self.evaluate(t1.values())) 982 983 if save._write_version is saver_pb2.SaverDef.V1: 984 self.assertEqual( 985 checkpoint_management.latest_checkpoint(self.get_temp_dir()), 986 os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002")) 987 else: 988 self.assertEqual( 989 checkpoint_management.latest_checkpoint(self.get_temp_dir()), 990 os.path.join(self.get_temp_dir(), "sharded_basics")) 991 992 def testSaverDef(self): 993 # train.Saver is V1 only API. 994 with ops_lib.Graph().as_default(), self.cached_session(): 995 v0 = variables.VariableV1(123, name="v0") 996 save = saver_module.Saver({"v0": v0}, sharded=True) 997 sd = save.as_saver_def() 998 self.assertTrue(sd.sharded) 999 1000 def _testPartitionedVariables(self, use_resource): 1001 var_full_shape = [10, 3] 1002 # Allows save/restore mechanism to work w/ different slicings. 1003 var_name = "my_var" 1004 saved_dir = self._get_test_dir("partitioned_variables") 1005 saved_path = os.path.join(saved_dir, "ckpt") 1006 1007 call_saver_with_dict = False # updated by test loop below 1008 1009 def _save(partitioner=None): 1010 # train.Saver is V1 only API. 1011 with ops_lib.Graph().as_default(), self.session() as sess: 1012 # Calls .eval() to return the ndarray that makes up the full variable. 1013 rnd = random_ops.random_uniform(var_full_shape).eval() 1014 1015 if partitioner: 1016 vs = [ 1017 variable_scope.get_variable( 1018 var_name, 1019 shape=var_full_shape, 1020 initializer=rnd, 1021 partitioner=partitioner, 1022 use_resource=use_resource) 1023 ] 1024 else: 1025 if use_resource: 1026 vs = [resource_variable_ops.ResourceVariable(rnd, name=var_name)] 1027 else: 1028 vs = [variables.VariableV1(rnd, name=var_name)] 1029 1030 self.evaluate(variables.global_variables_initializer()) 1031 if call_saver_with_dict: 1032 saver = saver_module.Saver({var_name: vs[0]}) 1033 else: 1034 saver = saver_module.Saver(vs) 1035 actual_path = saver.save(sess, saved_path) 1036 self.assertEqual(saved_path, actual_path) 1037 1038 return rnd 1039 1040 def _restore(partitioner=None): 1041 # train.Saver is V1 only API. 1042 with ops_lib.Graph().as_default(), self.session() as sess: 1043 if partitioner: 1044 new_vs = [ 1045 variable_scope.get_variable( 1046 var_name, 1047 shape=var_full_shape, 1048 initializer=array_ops.zeros(var_full_shape), 1049 partitioner=partitioner) 1050 ] 1051 else: 1052 new_vs = [ 1053 variables.VariableV1( 1054 array_ops.zeros( 1055 shape=var_full_shape), # != original contents. 1056 name=var_name) 1057 ] 1058 1059 self.evaluate(variables.global_variables_initializer()) 1060 if call_saver_with_dict: 1061 saver = saver_module.Saver({ 1062 var_name: new_vs[0] 1063 }) 1064 else: 1065 saver = saver_module.Saver(new_vs) 1066 saver.restore(sess, saved_path) 1067 1068 if partitioner: 1069 return new_vs[0].as_tensor().eval() 1070 else: 1071 return new_vs[0].eval() 1072 1073 for call_saver_with_dict in {False, True}: 1074 # Save PartitionedVariable and restore into full variable. 1075 saved_full = _save( 1076 partitioner=partitioned_variables.fixed_size_partitioner( 1077 num_shards=2)) 1078 restored_full = _restore() 1079 self.assertAllEqual(saved_full, restored_full) 1080 1081 # Restores into the same number of partitions. 1082 restored_full = _restore( 1083 partitioner=partitioned_variables.fixed_size_partitioner( 1084 num_shards=2)) 1085 self.assertAllEqual(saved_full, restored_full) 1086 1087 # Restores into a different number of partitions. 1088 restored_full = _restore( 1089 partitioner=partitioned_variables.fixed_size_partitioner( 1090 num_shards=3)) 1091 self.assertAllEqual(saved_full, restored_full) 1092 1093 # Now, saves a full variable and restores PartitionedVariable. 1094 saved_full = _save() 1095 restored_full = _restore( 1096 partitioner=partitioned_variables.fixed_size_partitioner( 1097 num_shards=3)) 1098 self.assertAllEqual(saved_full, restored_full) 1099 1100 def testPartitionedVariable(self): 1101 self._testPartitionedVariables(use_resource=False) 1102 1103 def testPartitionedResourceVariable(self): 1104 self._testPartitionedVariables(use_resource=True) 1105 1106 1107class SaveRestoreShardedTestV2(SaveRestoreShardedTest): 1108 _WRITE_VERSION = saver_pb2.SaverDef.V2 1109 1110 def testIterators(self): 1111 save_path = os.path.join(self.get_temp_dir(), "sharded_iterators") 1112 1113 # Build a graph with 2 parameter nodes on different devices and save. 1114 with session.Session( 1115 target="", 1116 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 1117 with sess.graph.device("/cpu:0"): 1118 ds0 = dataset_ops.Dataset.range(10) 1119 it0 = dataset_ops.make_initializable_iterator(ds0) 1120 get_next0 = it0.get_next() 1121 saveable0 = iterator_ops._IteratorSaveable( 1122 it0._iterator_resource, name="saveable_it0") 1123 1124 with sess.graph.device("/cpu:1"): 1125 ds1 = dataset_ops.Dataset.range(20) 1126 it1 = dataset_ops.make_initializable_iterator(ds1) 1127 get_next1 = it1.get_next() 1128 saveable1 = iterator_ops._IteratorSaveable( 1129 it1._iterator_resource, name="saveable_it1") 1130 saver = saver_module.Saver({ 1131 "it0": saveable0, 1132 "it1": saveable1 1133 }, 1134 write_version=self._WRITE_VERSION, 1135 sharded=True) 1136 self.evaluate(it0.initializer) 1137 self.evaluate(it1.initializer) 1138 self.assertEqual(0, self.evaluate(get_next0)) 1139 self.assertEqual(1, self.evaluate(get_next0)) 1140 self.assertEqual(0, self.evaluate(get_next1)) 1141 val = saver.save(sess, save_path) 1142 self.assertEqual(save_path, val) 1143 data_files = glob.glob(save_path + ".data*") 1144 self.assertEqual(2, len(data_files)) 1145 1146 # Restore 1147 with session.Session( 1148 target="", 1149 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 1150 with sess.graph.device("/cpu:0"): 1151 ds0 = dataset_ops.Dataset.range(10) 1152 it0 = dataset_ops.make_initializable_iterator(ds0) 1153 get_next0 = it0.get_next() 1154 saveable0 = iterator_ops._IteratorSaveable( 1155 it0._iterator_resource, name="saveable_it0") 1156 1157 with sess.graph.device("/cpu:1"): 1158 ds1 = dataset_ops.Dataset.range(20) 1159 it1 = dataset_ops.make_initializable_iterator(ds1) 1160 get_next1 = it1.get_next() 1161 saveable1 = iterator_ops._IteratorSaveable( 1162 it1._iterator_resource, name="saveable_it1") 1163 saver = saver_module.Saver({ 1164 "it0": saveable0, 1165 "it1": saveable1 1166 }, 1167 write_version=self._WRITE_VERSION, 1168 sharded=True) 1169 self.evaluate(it0.initializer) 1170 self.evaluate(it1.initializer) 1171 saver.restore(sess, save_path) 1172 self.assertEqual(2, self.evaluate(get_next0)) 1173 self.assertEqual(1, self.evaluate(get_next1)) 1174 1175 def testIteratorsUnshardedRestore(self): 1176 save_path = os.path.join(self.get_temp_dir(), "restore_unsharded_iterators") 1177 1178 # Build a graph with 2 parameter nodes on different devices and save. 1179 with session.Session( 1180 target="", 1181 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 1182 with sess.graph.device("/cpu:0"): 1183 ds0 = dataset_ops.Dataset.range(10) 1184 it0 = dataset_ops.make_initializable_iterator(ds0) 1185 get_next0 = it0.get_next() 1186 saveable0 = iterator_ops._IteratorSaveable( 1187 it0._iterator_resource, name="saveable_it0") 1188 1189 with sess.graph.device("/cpu:1"): 1190 ds1 = dataset_ops.Dataset.range(20) 1191 it1 = dataset_ops.make_initializable_iterator(ds1) 1192 get_next1 = it1.get_next() 1193 saveable1 = iterator_ops._IteratorSaveable( 1194 it1._iterator_resource, name="saveable_it1") 1195 saver = saver_module.Saver({ 1196 "it0": saveable0, 1197 "it1": saveable1 1198 }, 1199 write_version=self._WRITE_VERSION, 1200 sharded=True) 1201 self.evaluate(it0.initializer) 1202 self.evaluate(it1.initializer) 1203 self.assertEqual(0, self.evaluate(get_next0)) 1204 self.assertEqual(1, self.evaluate(get_next0)) 1205 self.assertEqual(0, self.evaluate(get_next1)) 1206 val = saver.save(sess, save_path) 1207 self.assertEqual(save_path, val) 1208 data_files = glob.glob(save_path + ".data*") 1209 self.assertEqual(2, len(data_files)) 1210 1211 # Restore 1212 with session.Session( 1213 target="", 1214 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 1215 with sess.graph.device("/cpu:0"): 1216 ds0 = dataset_ops.Dataset.range(10) 1217 it0 = dataset_ops.make_initializable_iterator(ds0) 1218 get_next0 = it0.get_next() 1219 saveable0 = iterator_ops._IteratorSaveable( 1220 it0._iterator_resource, name="saveable_it0") 1221 1222 with sess.graph.device("/cpu:1"): 1223 ds1 = dataset_ops.Dataset.range(20) 1224 it1 = dataset_ops.make_initializable_iterator(ds1) 1225 get_next1 = it1.get_next() 1226 saveable1 = iterator_ops._IteratorSaveable( 1227 it1._iterator_resource, name="saveable_it1") 1228 saver = saver_module.Saver({ 1229 "it0": saveable0, 1230 "it1": saveable1 1231 }, 1232 write_version=self._WRITE_VERSION, 1233 sharded=False) 1234 self.evaluate(it0.initializer) 1235 self.evaluate(it1.initializer) 1236 saver.restore(sess, save_path) 1237 self.assertEqual(2, self.evaluate(get_next0)) 1238 self.assertEqual(1, self.evaluate(get_next1)) 1239 1240 1241class MaxToKeepTest(test.TestCase): 1242 1243 def _get_test_dir(self, dirname): 1244 test_dir = os.path.join(self.get_temp_dir(), dirname) 1245 gfile.MakeDirs(test_dir) 1246 return test_dir 1247 1248 def assertCheckpointState(self, model_checkpoint_path, 1249 all_model_checkpoint_paths, save_dir): 1250 checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir) 1251 self.assertEqual(checkpoint_state.model_checkpoint_path, 1252 model_checkpoint_path) 1253 self.assertEqual(checkpoint_state.all_model_checkpoint_paths, 1254 all_model_checkpoint_paths) 1255 1256 def testMaxToKeepEager(self): 1257 with context.eager_mode(): 1258 save_dir = self._get_test_dir("max_to_keep_eager") 1259 1260 v = variable_scope.variable(10.0, name="v") 1261 save = saver_module.Saver({"v": v}, max_to_keep=2) 1262 self.evaluate(variables.global_variables_initializer()) 1263 if not context.executing_eagerly(): 1264 self.assertEqual([], save.last_checkpoints) 1265 1266 s1 = save.save(None, os.path.join(save_dir, "s1")) 1267 self.assertEqual([s1], save.last_checkpoints) 1268 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1269 self.assertCheckpointState( 1270 model_checkpoint_path=s1, 1271 all_model_checkpoint_paths=[s1], 1272 save_dir=save_dir) 1273 1274 s2 = save.save(None, os.path.join(save_dir, "s2")) 1275 self.assertEqual([s1, s2], save.last_checkpoints) 1276 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1277 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1278 self.assertCheckpointState( 1279 model_checkpoint_path=s2, 1280 all_model_checkpoint_paths=[s1, s2], 1281 save_dir=save_dir) 1282 1283 s3 = save.save(None, os.path.join(save_dir, "s3")) 1284 self.assertEqual([s2, s3], save.last_checkpoints) 1285 self.assertFalse(checkpoint_management.checkpoint_exists(s1)) 1286 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1287 self.assertTrue(checkpoint_management.checkpoint_exists(s3)) 1288 self.assertCheckpointState( 1289 model_checkpoint_path=s3, 1290 all_model_checkpoint_paths=[s2, s3], 1291 save_dir=save_dir) 1292 1293 # Create a second helper, identical to the first. 1294 save2 = saver_module.Saver({"v": v}, max_to_keep=2) 1295 save2.set_last_checkpoints(save.last_checkpoints) 1296 1297 # Exercise the first helper. 1298 1299 # Adding s2 again (old s2 is removed first, then new s2 appended) 1300 s2 = save.save(None, os.path.join(save_dir, "s2")) 1301 self.assertEqual([s3, s2], save.last_checkpoints) 1302 self.assertFalse(checkpoint_management.checkpoint_exists(s1)) 1303 self.assertTrue(checkpoint_management.checkpoint_exists(s3)) 1304 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1305 self.assertCheckpointState( 1306 model_checkpoint_path=s2, 1307 all_model_checkpoint_paths=[s3, s2], 1308 save_dir=save_dir) 1309 1310 # Adding s1 (s3 should now be deleted as oldest in list) 1311 s1 = save.save(None, os.path.join(save_dir, "s1")) 1312 self.assertEqual([s2, s1], save.last_checkpoints) 1313 self.assertFalse(checkpoint_management.checkpoint_exists(s3)) 1314 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1315 self.assertCheckpointState( 1316 model_checkpoint_path=s1, 1317 all_model_checkpoint_paths=[s2, s1], 1318 save_dir=save_dir) 1319 1320 s2 = save2.save(None, os.path.join(save_dir, "s2")) 1321 self.assertEqual([s3, s2], save2.last_checkpoints) 1322 # Created by the first helper. 1323 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1324 # Deleted by the first helper. 1325 self.assertFalse(checkpoint_management.checkpoint_exists(s3)) 1326 1327 def testNonSharded(self): 1328 save_dir = self._get_test_dir("max_to_keep_non_sharded") 1329 1330 # train.Saver is V1 only API. 1331 with ops_lib.Graph().as_default(), self.cached_session() as sess: 1332 v = variables.VariableV1(10.0, name="v") 1333 save = saver_module.Saver({"v": v}, max_to_keep=2) 1334 self.evaluate(variables.global_variables_initializer()) 1335 self.assertEqual([], save.last_checkpoints) 1336 1337 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1338 self.assertEqual([s1], save.last_checkpoints) 1339 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1340 self.assertCheckpointState( 1341 model_checkpoint_path=s1, 1342 all_model_checkpoint_paths=[s1], 1343 save_dir=save_dir) 1344 1345 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1346 self.assertEqual([s1, s2], save.last_checkpoints) 1347 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1348 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1349 self.assertCheckpointState( 1350 model_checkpoint_path=s2, 1351 all_model_checkpoint_paths=[s1, s2], 1352 save_dir=save_dir) 1353 1354 s3 = save.save(sess, os.path.join(save_dir, "s3")) 1355 self.assertEqual([s2, s3], save.last_checkpoints) 1356 self.assertFalse(checkpoint_management.checkpoint_exists(s1)) 1357 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1358 self.assertTrue(checkpoint_management.checkpoint_exists(s3)) 1359 self.assertCheckpointState( 1360 model_checkpoint_path=s3, 1361 all_model_checkpoint_paths=[s2, s3], 1362 save_dir=save_dir) 1363 1364 # Create a second helper, identical to the first. 1365 save2 = saver_module.Saver(saver_def=save.as_saver_def()) 1366 save2.set_last_checkpoints(save.last_checkpoints) 1367 1368 # Create a third helper, with the same configuration but no knowledge of 1369 # previous checkpoints. 1370 save3 = saver_module.Saver(saver_def=save.as_saver_def()) 1371 1372 # Exercise the first helper. 1373 1374 # Adding s2 again (old s2 is removed first, then new s2 appended) 1375 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1376 self.assertEqual([s3, s2], save.last_checkpoints) 1377 self.assertFalse(checkpoint_management.checkpoint_exists(s1)) 1378 self.assertFalse( 1379 checkpoint_management.checkpoint_exists( 1380 checkpoint_management.meta_graph_filename(s1))) 1381 self.assertTrue(checkpoint_management.checkpoint_exists(s3)) 1382 self.assertTrue( 1383 checkpoint_management.checkpoint_exists( 1384 checkpoint_management.meta_graph_filename(s3))) 1385 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1386 self.assertTrue( 1387 checkpoint_management.checkpoint_exists( 1388 checkpoint_management.meta_graph_filename(s2))) 1389 self.assertCheckpointState( 1390 model_checkpoint_path=s2, 1391 all_model_checkpoint_paths=[s3, s2], 1392 save_dir=save_dir) 1393 1394 # Adding s1 (s3 should now be deleted as oldest in list) 1395 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1396 self.assertEqual([s2, s1], save.last_checkpoints) 1397 self.assertFalse(checkpoint_management.checkpoint_exists(s3)) 1398 self.assertFalse( 1399 checkpoint_management.checkpoint_exists( 1400 checkpoint_management.meta_graph_filename(s3))) 1401 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1402 self.assertTrue( 1403 checkpoint_management.checkpoint_exists( 1404 checkpoint_management.meta_graph_filename(s2))) 1405 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1406 self.assertTrue( 1407 checkpoint_management.checkpoint_exists( 1408 checkpoint_management.meta_graph_filename(s1))) 1409 self.assertCheckpointState( 1410 model_checkpoint_path=s1, 1411 all_model_checkpoint_paths=[s2, s1], 1412 save_dir=save_dir) 1413 1414 # Exercise the second helper. 1415 1416 # Adding s2 again (old s2 is removed first, then new s2 appended) 1417 s2 = save2.save(sess, os.path.join(save_dir, "s2")) 1418 self.assertEqual([s3, s2], save2.last_checkpoints) 1419 # Created by the first helper. 1420 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1421 self.assertTrue( 1422 checkpoint_management.checkpoint_exists( 1423 checkpoint_management.meta_graph_filename(s1))) 1424 # Deleted by the first helper. 1425 self.assertFalse(checkpoint_management.checkpoint_exists(s3)) 1426 self.assertFalse( 1427 checkpoint_management.checkpoint_exists( 1428 checkpoint_management.meta_graph_filename(s3))) 1429 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1430 self.assertTrue( 1431 checkpoint_management.checkpoint_exists( 1432 checkpoint_management.meta_graph_filename(s2))) 1433 self.assertCheckpointState( 1434 model_checkpoint_path=s2, 1435 all_model_checkpoint_paths=[s3, s2], 1436 save_dir=save_dir) 1437 1438 # Adding s1 (s3 should now be deleted as oldest in list) 1439 s1 = save2.save(sess, os.path.join(save_dir, "s1")) 1440 self.assertEqual([s2, s1], save2.last_checkpoints) 1441 self.assertFalse(checkpoint_management.checkpoint_exists(s3)) 1442 self.assertFalse( 1443 checkpoint_management.checkpoint_exists( 1444 checkpoint_management.meta_graph_filename(s3))) 1445 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1446 self.assertTrue( 1447 checkpoint_management.checkpoint_exists( 1448 checkpoint_management.meta_graph_filename(s2))) 1449 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1450 self.assertTrue( 1451 checkpoint_management.checkpoint_exists( 1452 checkpoint_management.meta_graph_filename(s1))) 1453 self.assertCheckpointState( 1454 model_checkpoint_path=s1, 1455 all_model_checkpoint_paths=[s2, s1], 1456 save_dir=save_dir) 1457 1458 # Exercise the third helper. 1459 1460 # Adding s2 again (but helper is unaware of previous s2) 1461 s2 = save3.save(sess, os.path.join(save_dir, "s2")) 1462 self.assertEqual([s2], save3.last_checkpoints) 1463 # Created by the first helper. 1464 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1465 self.assertTrue( 1466 checkpoint_management.checkpoint_exists( 1467 checkpoint_management.meta_graph_filename(s1))) 1468 # Deleted by the first helper. 1469 self.assertFalse(checkpoint_management.checkpoint_exists(s3)) 1470 self.assertFalse( 1471 checkpoint_management.checkpoint_exists( 1472 checkpoint_management.meta_graph_filename(s3))) 1473 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1474 self.assertTrue( 1475 checkpoint_management.checkpoint_exists( 1476 checkpoint_management.meta_graph_filename(s2))) 1477 # Even though the file for s1 exists, this saver isn't aware of it, which 1478 # is why it doesn't end up in the checkpoint state. 1479 self.assertCheckpointState( 1480 model_checkpoint_path=s2, 1481 all_model_checkpoint_paths=[s2], 1482 save_dir=save_dir) 1483 1484 # Adding s1 (s3 should not be deleted because helper is unaware of it) 1485 s1 = save3.save(sess, os.path.join(save_dir, "s1")) 1486 self.assertEqual([s2, s1], save3.last_checkpoints) 1487 self.assertFalse(checkpoint_management.checkpoint_exists(s3)) 1488 self.assertFalse( 1489 checkpoint_management.checkpoint_exists( 1490 checkpoint_management.meta_graph_filename(s3))) 1491 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1492 self.assertTrue( 1493 checkpoint_management.checkpoint_exists( 1494 checkpoint_management.meta_graph_filename(s2))) 1495 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1496 self.assertTrue( 1497 checkpoint_management.checkpoint_exists( 1498 checkpoint_management.meta_graph_filename(s1))) 1499 self.assertCheckpointState( 1500 model_checkpoint_path=s1, 1501 all_model_checkpoint_paths=[s2, s1], 1502 save_dir=save_dir) 1503 1504 def testSharded(self): 1505 save_dir = self._get_test_dir("max_to_keep_sharded") 1506 1507 with session.Session( 1508 target="", 1509 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 1510 with sess.graph.device("/cpu:0"): 1511 v0 = variables.VariableV1(111, name="v0") 1512 with sess.graph.device("/cpu:1"): 1513 v1 = variables.VariableV1(222, name="v1") 1514 save = saver_module.Saver( 1515 { 1516 "v0": v0, 1517 "v1": v1 1518 }, sharded=True, max_to_keep=2) 1519 self.evaluate(variables.global_variables_initializer()) 1520 self.assertEqual([], save.last_checkpoints) 1521 1522 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1523 self.assertEqual([s1], save.last_checkpoints) 1524 if save._write_version is saver_pb2.SaverDef.V1: 1525 self.assertEqual(2, len(gfile.Glob(s1))) 1526 else: 1527 self.assertEqual(4, len(gfile.Glob(s1 + "*"))) 1528 1529 self.assertTrue( 1530 gfile.Exists(checkpoint_management.meta_graph_filename(s1))) 1531 1532 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1533 self.assertEqual([s1, s2], save.last_checkpoints) 1534 if save._write_version is saver_pb2.SaverDef.V1: 1535 self.assertEqual(2, len(gfile.Glob(s1))) 1536 else: 1537 self.assertEqual(4, len(gfile.Glob(s1 + "*"))) 1538 self.assertTrue( 1539 gfile.Exists(checkpoint_management.meta_graph_filename(s1))) 1540 if save._write_version is saver_pb2.SaverDef.V1: 1541 self.assertEqual(2, len(gfile.Glob(s2))) 1542 else: 1543 self.assertEqual(4, len(gfile.Glob(s2 + "*"))) 1544 self.assertTrue( 1545 gfile.Exists(checkpoint_management.meta_graph_filename(s2))) 1546 1547 s3 = save.save(sess, os.path.join(save_dir, "s3")) 1548 self.assertEqual([s2, s3], save.last_checkpoints) 1549 self.assertEqual(0, len(gfile.Glob(s1 + "*"))) 1550 self.assertFalse( 1551 gfile.Exists(checkpoint_management.meta_graph_filename(s1))) 1552 if save._write_version is saver_pb2.SaverDef.V1: 1553 self.assertEqual(2, len(gfile.Glob(s2))) 1554 else: 1555 self.assertEqual(4, len(gfile.Glob(s2 + "*"))) 1556 self.assertTrue( 1557 gfile.Exists(checkpoint_management.meta_graph_filename(s2))) 1558 if save._write_version is saver_pb2.SaverDef.V1: 1559 self.assertEqual(2, len(gfile.Glob(s3))) 1560 else: 1561 self.assertEqual(4, len(gfile.Glob(s3 + "*"))) 1562 self.assertTrue( 1563 gfile.Exists(checkpoint_management.meta_graph_filename(s3))) 1564 1565 def testNoMaxToKeep(self): 1566 save_dir = self._get_test_dir("no_max_to_keep") 1567 save_dir2 = self._get_test_dir("max_to_keep_0") 1568 1569 with self.cached_session() as sess: 1570 v = variables.VariableV1(10.0, name="v") 1571 self.evaluate(variables.global_variables_initializer()) 1572 1573 # Test max_to_keep being None. 1574 save = saver_module.Saver({"v": v}, max_to_keep=None) 1575 self.assertEqual([], save.last_checkpoints) 1576 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1577 self.assertEqual([], save.last_checkpoints) 1578 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1579 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1580 self.assertEqual([], save.last_checkpoints) 1581 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1582 1583 # Test max_to_keep being 0. 1584 save2 = saver_module.Saver({"v": v}, max_to_keep=0) 1585 self.assertEqual([], save2.last_checkpoints) 1586 s1 = save2.save(sess, os.path.join(save_dir2, "s1")) 1587 self.assertEqual([], save2.last_checkpoints) 1588 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1589 s2 = save2.save(sess, os.path.join(save_dir2, "s2")) 1590 self.assertEqual([], save2.last_checkpoints) 1591 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1592 1593 def testNoMetaGraph(self): 1594 save_dir = self._get_test_dir("no_meta_graph") 1595 1596 with self.cached_session() as sess: 1597 v = variables.VariableV1(10.0, name="v") 1598 save = saver_module.Saver({"v": v}) 1599 self.evaluate(variables.global_variables_initializer()) 1600 1601 s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False) 1602 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1603 self.assertFalse( 1604 gfile.Exists(checkpoint_management.meta_graph_filename(s1))) 1605 1606 1607class RecoverLastCheckpointsTest(test.TestCase): 1608 1609 def _get_test_dir(self, dirname): 1610 test_dir = os.path.join(self.get_temp_dir(), dirname) 1611 gfile.MakeDirs(test_dir) 1612 return test_dir 1613 1614 def assertCheckpointState(self, model_checkpoint_path, 1615 all_model_checkpoint_paths, save_dir): 1616 checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir) 1617 self.assertEqual(checkpoint_state.model_checkpoint_path, 1618 model_checkpoint_path) 1619 self.assertEqual(checkpoint_state.all_model_checkpoint_paths, 1620 all_model_checkpoint_paths) 1621 1622 def test_recover_last_checkpoints(self): 1623 with context.eager_mode(): 1624 save_dir = self._get_test_dir("recover_last_checkpoints") 1625 1626 v = variable_scope.variable(10.0, name="v") 1627 save = saver_module.Saver({"v": v}, max_to_keep=10) 1628 self.evaluate(variables.global_variables_initializer()) 1629 self.assertEqual([], save.last_checkpoints) 1630 1631 s1 = save.save(None, os.path.join(save_dir, "ckpt-1")) 1632 s2 = save.save(None, os.path.join(save_dir, "ckpt-2")) 1633 s3 = save.save(None, os.path.join(save_dir, "ckpt-3")) 1634 self.assertEqual([s1, s2, s3], save.last_checkpoints) 1635 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1636 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1637 self.assertTrue(checkpoint_management.checkpoint_exists(s3)) 1638 self.assertCheckpointState( 1639 model_checkpoint_path=s3, 1640 all_model_checkpoint_paths=[s1, s2, s3], 1641 save_dir=save_dir) 1642 1643 # Create another saver and recover last checkpoints. 1644 save2 = saver_module.Saver({"v": v}, max_to_keep=10) 1645 self.assertEqual([], save2.last_checkpoints) 1646 save2.recover_last_checkpoints([s1, s2, s3]) 1647 self.assertEqual([s1, s2, s3], save2.last_checkpoints) 1648 1649 # Remove a checkpoint and check that last checkpoints are 1650 # restored correctly. 1651 for fname in gfile.Glob("{}*".format(s1)): 1652 gfile.Remove(fname) 1653 self.assertFalse(checkpoint_management.checkpoint_exists(s1)) 1654 1655 # Create another saver and recover last checkpoints. The removed 1656 # checkpoint would be correctly omitted. 1657 save3 = saver_module.Saver({"v": v}, max_to_keep=10) 1658 self.assertEqual([], save3.last_checkpoints) 1659 save3.recover_last_checkpoints([s1, s2, s3]) 1660 self.assertEqual([s2, s3], save3.last_checkpoints) 1661 s4 = save3.save(None, os.path.join(save_dir, "ckpt-4")) 1662 self.assertCheckpointState( 1663 model_checkpoint_path=s4, 1664 all_model_checkpoint_paths=[s2, s3, s4], 1665 save_dir=save_dir) 1666 1667 1668class KeepCheckpointEveryNHoursTest(test.TestCase): 1669 1670 def _get_test_dir(self, dirname): 1671 test_dir = os.path.join(self.get_temp_dir(), dirname) 1672 gfile.MakeDirs(test_dir) 1673 return test_dir 1674 1675 @test_util.run_in_graph_and_eager_modes 1676 @test.mock.patch.object(saver_module, "time") 1677 def testNonSharded(self, mock_time): 1678 save_dir = self._get_test_dir("keep_checkpoint_every_n_hours") 1679 1680 with self.cached_session() as sess: 1681 v = variable_scope.variable([10.0], name="v") 1682 # Run the initializer NOW to avoid the 0.5s overhead of the first Run() 1683 # call, which throws the test timing off in fastbuild mode. 1684 self.evaluate(variables.global_variables_initializer()) 1685 # Create a saver that will keep the last 2 checkpoints plus one every 0.7 1686 # seconds. 1687 start_time = time.time() 1688 mock_time.time.return_value = start_time 1689 save = saver_module.Saver( 1690 { 1691 "v": v 1692 }, max_to_keep=2, keep_checkpoint_every_n_hours=0.7 / 3600) 1693 self.assertEqual([], save.last_checkpoints) 1694 1695 # Wait till 1 seconds have elapsed so s1 will be old enough to keep. 1696 # sleep may return early, don't trust it. 1697 mock_time.time.return_value = start_time + 1.0 1698 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1699 self.assertEqual([s1], save.last_checkpoints) 1700 1701 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1702 self.assertEqual([s1, s2], save.last_checkpoints) 1703 1704 # We now have 2 'last_checkpoints': [s1, s2]. The next call to Save(), 1705 # would normally delete s1, because max_to_keep is 2. However, s1 is 1706 # older than 0.7s so we must keep it. 1707 s3 = save.save(sess, os.path.join(save_dir, "s3")) 1708 self.assertEqual([s2, s3], save.last_checkpoints) 1709 1710 # s1 should still be here, we are Not checking now to reduce time 1711 # variance in the test. 1712 1713 # We now have 2 'last_checkpoints': [s2, s3], and s1 on disk. The next 1714 # call to Save(), will delete s2, because max_to_keep is 2, and because 1715 # we already kept the old s1. s2 is very close in time to s1 so it gets 1716 # deleted. 1717 s4 = save.save(sess, os.path.join(save_dir, "s4")) 1718 self.assertEqual([s3, s4], save.last_checkpoints) 1719 1720 # Check that s1 is still here, but s2 is gone. 1721 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1722 self.assertFalse(checkpoint_management.checkpoint_exists(s2)) 1723 self.assertTrue(checkpoint_management.checkpoint_exists(s3)) 1724 self.assertTrue(checkpoint_management.checkpoint_exists(s4)) 1725 1726 1727class SaveRestoreWithVariableNameMap(test.TestCase): 1728 1729 def _testNonReshape(self, variable_op): 1730 save_path = os.path.join(self.get_temp_dir(), "non_reshape") 1731 1732 with self.session(graph=ops_lib.Graph()) as sess: 1733 # Build a graph with 2 parameter nodes, and Save and 1734 # Restore nodes for them. 1735 v0 = variable_op(10.0, name="v0") 1736 v1 = variable_op(20.0, name="v1") 1737 save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1}) 1738 self.evaluate(variables.global_variables_initializer()) 1739 1740 # Check that the parameter nodes have been initialized. 1741 self.assertEqual(10.0, self.evaluate(v0)) 1742 self.assertEqual(20.0, self.evaluate(v1)) 1743 1744 # Save the initialized values in the file at "save_path" 1745 # Use a variable name map to set the saved tensor names 1746 val = save.save(sess, save_path) 1747 self.assertTrue(isinstance(val, six.string_types)) 1748 self.assertEqual(save_path, val) 1749 1750 # Verify that the original names are not in the Saved file 1751 save = saver_module.Saver({"v0": v0, "v1": v1}) 1752 with self.assertRaisesOpError("not found in checkpoint"): 1753 save.restore(sess, save_path) 1754 1755 # Verify that the mapped names are present in the Saved file and can be 1756 # Restored using remapped names. 1757 with self.session(graph=ops_lib.Graph()) as sess: 1758 v0 = variable_op(-1.0, name="v0") 1759 v1 = variable_op(-1.0, name="v1") 1760 1761 if not context.executing_eagerly(): 1762 with self.assertRaisesOpError("uninitialized"): 1763 self.evaluate(v0) 1764 with self.assertRaisesOpError("uninitialized"): 1765 self.evaluate(v1) 1766 1767 save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1}) 1768 save.restore(sess, save_path) 1769 1770 # Check that the parameter nodes have been restored. 1771 if not context.executing_eagerly(): 1772 self.assertEqual(10.0, self.evaluate(v0)) 1773 self.assertEqual(20.0, self.evaluate(v1)) 1774 1775 # Add a prefix to the node names in the current graph and Restore using 1776 # remapped names. 1777 with self.session(graph=ops_lib.Graph()) as sess: 1778 v0 = variable_op(-1.0, name="restore_prefix/v0") 1779 v1 = variable_op(-1.0, name="restore_prefix/v1") 1780 1781 if not context.executing_eagerly(): 1782 with self.assertRaisesOpError("uninitialized"): 1783 self.evaluate(v0) 1784 with self.assertRaisesOpError("uninitialized"): 1785 self.evaluate(v1) 1786 1787 # Restore the saved values in the parameter nodes. 1788 save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1}) 1789 save.restore(sess, save_path) 1790 1791 # Check that the parameter nodes have been restored. 1792 self.assertEqual(10.0, self.evaluate(v0)) 1793 self.assertEqual(20.0, self.evaluate(v1)) 1794 1795 @test_util.run_in_graph_and_eager_modes 1796 def testNonReshapeResourceVariable(self): 1797 self._testNonReshape(resource_variable_ops.ResourceVariable) 1798 1799 def testNonReshapeVariable(self): 1800 self._testNonReshape(variables.Variable) 1801 1802 1803class MetaGraphTest(test.TestCase): 1804 1805 def _get_test_dir(self, dirname): 1806 test_dir = os.path.join(self.get_temp_dir(), dirname) 1807 gfile.MakeDirs(test_dir) 1808 return test_dir 1809 1810 @test_util.run_v1_only( 1811 "Queue-based input pipelines have been replaced by `tf.data` " 1812 "and not supported in V2.") 1813 def testAddCollectionDef(self): 1814 test_dir = self._get_test_dir("good_collection") 1815 filename = os.path.join(test_dir, "metafile") 1816 with self.cached_session(): 1817 # Creates a graph. 1818 v0 = variables.VariableV1(1.0, name="v0") 1819 control_flow_ops.cond( 1820 math_ops.less(v0, 10), lambda: math_ops.add(v0, 1), 1821 lambda: math_ops.subtract(v0, 1)) 1822 control_flow_ops.while_loop(lambda i: math_ops.less(i, 10), 1823 lambda i: math_ops.add(i, 1), [v0]) 1824 var = variables.VariableV1(constant_op.constant(0, dtype=dtypes.int64)) 1825 count_up_to = var.count_up_to(3) 1826 input_queue = data_flow_ops.FIFOQueue( 1827 30, dtypes.float32, shared_name="collection_queue") 1828 qr = queue_runner_impl.QueueRunner(input_queue, [count_up_to]) 1829 variables.global_variables_initializer() 1830 # Creates a saver. 1831 save = saver_module.Saver({"v0": v0}) 1832 # Adds a set of collections. 1833 ops_lib.add_to_collection("int_collection", 3) 1834 ops_lib.add_to_collection("float_collection", 3.5) 1835 ops_lib.add_to_collection("string_collection", "hello") 1836 ops_lib.add_to_collection("variable_collection", v0) 1837 # Add QueueRunners. 1838 queue_runner_impl.add_queue_runner(qr) 1839 # Adds user_defined proto in three formats: string, bytes and Any. 1840 queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue") 1841 ops_lib.add_to_collection("user_defined_string_collection", 1842 str(queue_runner)) 1843 ops_lib.add_to_collection("user_defined_bytes_collection", 1844 queue_runner.SerializeToString()) 1845 any_buf = Any() 1846 any_buf.Pack(queue_runner) 1847 ops_lib.add_to_collection("user_defined_any_collection", any_buf) 1848 1849 # Generates MetaGraphDef. 1850 meta_graph_def = save.export_meta_graph(filename) 1851 self.assertTrue(meta_graph_def.HasField("saver_def")) 1852 self.assertTrue(meta_graph_def.HasField("graph_def")) 1853 self.assertTrue(meta_graph_def.HasField("meta_info_def")) 1854 self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_version, "") 1855 self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_git_version, 1856 "") 1857 collection_def = meta_graph_def.collection_def 1858 self.assertEqual(len(collection_def), 12) 1859 1860 with ops_lib.Graph().as_default(): 1861 # Restores from MetaGraphDef. 1862 new_saver = saver_module.import_meta_graph(filename) 1863 # Generates a new MetaGraphDef. 1864 new_meta_graph_def = new_saver.export_meta_graph() 1865 # It should be the same as the original. 1866 1867 test_util.assert_meta_graph_protos_equal( 1868 self, meta_graph_def, new_meta_graph_def) 1869 1870 def testAddCollectionDefFails(self): 1871 with self.cached_session(): 1872 # Creates a graph. 1873 v0 = variables.VariableV1(10.0, name="v0") 1874 # Creates a saver. 1875 save = saver_module.Saver({"v0": v0}) 1876 # Generates MetaGraphDef. 1877 meta_graph_def = meta_graph_pb2.MetaGraphDef() 1878 1879 # Verifies that collection with unsupported key will not be added. 1880 ops_lib.add_to_collection(save, 3) 1881 save._add_collection_def(meta_graph_def, save) 1882 self.assertEqual(len(meta_graph_def.collection_def), 0) 1883 1884 # Verifies that collection where item type does not match expected 1885 # type will not be added. 1886 ops_lib.add_to_collection("int_collection", 3) 1887 ops_lib.add_to_collection("int_collection", 3.5) 1888 save._add_collection_def(meta_graph_def, "int_collection") 1889 self.assertEqual(len(meta_graph_def.collection_def), 0) 1890 1891 def _testMultiSaverCollectionSave(self, test_dir): 1892 filename = os.path.join(test_dir, "metafile") 1893 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 1894 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt") 1895 with self.session(graph=ops_lib.Graph()) as sess: 1896 # Creates a graph. 1897 v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0") 1898 v1 = variables.VariableV1(11.0, name="v1") 1899 # Creates 2 savers. 1900 saver0 = saver_module.Saver({"v0": v0}, name="saver0") 1901 saver1 = saver_module.Saver({"v1": v1}, name="saver1") 1902 ops_lib.add_to_collection("savers", saver0) 1903 ops_lib.add_to_collection("savers", saver1) 1904 self.evaluate(variables.global_variables_initializer()) 1905 # Saves to different checkpoints. 1906 saver0.save(sess, saver0_ckpt) 1907 saver1.save(sess, saver1_ckpt) 1908 # Generates MetaGraphDef. 1909 meta_graph_def = saver_module.export_meta_graph(filename) 1910 meta_graph_def0 = saver0.export_meta_graph() 1911 meta_graph_def1 = saver1.export_meta_graph() 1912 1913 # Verifies that there is no saver_def in meta_graph_def. 1914 self.assertFalse(meta_graph_def.HasField("saver_def")) 1915 # Verifies that there is saver_def in meta_graph_def0 and 1. 1916 self.assertTrue(meta_graph_def0.HasField("saver_def")) 1917 self.assertTrue(meta_graph_def1.HasField("saver_def")) 1918 1919 # Verifies SAVERS is saved as bytes_list for meta_graph_def. 1920 collection_def = meta_graph_def.collection_def["savers"] 1921 kind = collection_def.WhichOneof("kind") 1922 self.assertEqual(kind, "bytes_list") 1923 # Verifies that there are 2 entries in SAVERS collection. 1924 savers = getattr(collection_def, kind) 1925 self.assertEqual(2, len(savers.value)) 1926 1927 # Verifies SAVERS collection is saved as bytes_list for meta_graph_def0. 1928 collection_def = meta_graph_def0.collection_def["savers"] 1929 kind = collection_def.WhichOneof("kind") 1930 self.assertEqual(kind, "bytes_list") 1931 # Verifies that there are 2 entries in SAVERS collection. 1932 savers = getattr(collection_def, kind) 1933 self.assertEqual(2, len(savers.value)) 1934 1935 def _testMultiSaverCollectionRestore(self, test_dir): 1936 filename = os.path.join(test_dir, "metafile") 1937 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 1938 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt") 1939 with self.session(graph=ops_lib.Graph()) as sess: 1940 # Imports from meta_graph. 1941 saver_module.import_meta_graph(filename) 1942 # Retrieves SAVERS collection. Verifies there are 2 entries. 1943 savers = ops_lib.get_collection("savers") 1944 self.assertEqual(2, len(savers)) 1945 # Retrieves saver0. Verifies that new_saver0 can restore v0, but not v1. 1946 new_saver0 = savers[0] 1947 new_saver0.restore(sess, saver0_ckpt) 1948 v0 = sess.graph.get_tensor_by_name("v0:0") 1949 v1 = sess.graph.get_tensor_by_name("v1:0") 1950 self.assertAllEqual([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], 1951 self.evaluate(v0)) 1952 self.assertEqual([3, 2], v0.get_shape()) 1953 self.assertEqual([], v1.get_shape()) 1954 with self.assertRaisesWithPredicateMatch( 1955 errors_impl.OpError, lambda e: "uninitialized value v1" in e.message): 1956 self.evaluate(v1) 1957 # Retrieves saver1. Verifies that new_saver1 can restore v1. 1958 new_saver1 = savers[1] 1959 new_saver1.restore(sess, saver1_ckpt) 1960 v1 = sess.graph.get_tensor_by_name("v1:0") 1961 self.assertEqual(11.0, self.evaluate(v1)) 1962 1963 @test_util.run_v1_only( 1964 "Exporting/importing meta graphs is only supported in V1.") 1965 def testMultiSaverCollection(self): 1966 test_dir = self._get_test_dir("saver_collection") 1967 self._testMultiSaverCollectionSave(test_dir) 1968 self._testMultiSaverCollectionRestore(test_dir) 1969 1970 @test_util.run_v1_only( 1971 "Exporting/importing meta graphs is only supported in V1.") 1972 def testClearExtraneousSavers(self): 1973 test_dir = self._get_test_dir("clear_extraneous_savers") 1974 filename = os.path.join(test_dir, "metafile") 1975 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 1976 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt") 1977 with self.session(graph=ops_lib.Graph()) as sess: 1978 # Creates a graph. 1979 v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0") 1980 v1 = variables.VariableV1(11.0, name="v1") 1981 1982 # Creates 2 savers. 1983 saver0 = saver_module.Saver({"v0": v0}, name="saver0") 1984 saver1 = saver_module.Saver({"v1": v1}, name="saver1") 1985 ops_lib.add_to_collection("savers", saver0) 1986 ops_lib.add_to_collection("savers", saver1) 1987 self.evaluate(variables.global_variables_initializer()) 1988 1989 # Saves to different checkpoints. 1990 saver0.save(sess, saver0_ckpt) 1991 saver1.save(sess, saver1_ckpt) 1992 1993 # Generates MetaGraphDef. 1994 meta_graph_def = saver_module.export_meta_graph(filename) 1995 meta_graph_def0 = saver0.export_meta_graph() 1996 meta_graph_def1 = saver1.export_meta_graph(clear_extraneous_savers=True) 1997 1998 # Verifies that there is no saver_def in meta_graph_def. 1999 self.assertFalse(meta_graph_def.HasField("saver_def")) 2000 # Verifies that there is saver_def in meta_graph_def0 and 1. 2001 self.assertTrue(meta_graph_def0.HasField("saver_def")) 2002 self.assertTrue(meta_graph_def1.HasField("saver_def")) 2003 2004 # Verifies SAVERS is saved as bytes_list for meta_graph_def. 2005 collection_def = meta_graph_def.collection_def["savers"] 2006 kind = collection_def.WhichOneof("kind") 2007 self.assertEqual(kind, "bytes_list") 2008 2009 # Verifies that there are 2 entries in SAVERS collection. 2010 savers = getattr(collection_def, kind) 2011 self.assertEqual(2, len(savers.value)) 2012 2013 # Verifies SAVERS collection is saved as bytes_list for meta_graph_def1. 2014 collection_def = meta_graph_def1.collection_def["savers"] 2015 kind = collection_def.WhichOneof("kind") 2016 self.assertEqual(kind, "bytes_list") 2017 2018 # Verifies that there is 1 entry in SAVERS collection. 2019 savers = getattr(collection_def, kind) 2020 self.assertEqual(1, len(savers.value)) 2021 2022 # Verifies that saver0 graph nodes are omitted from the saver1 export 2023 self.assertEqual(33, len(meta_graph_def0.graph_def.node)) 2024 self.assertEqual(21, len(meta_graph_def1.graph_def.node)) 2025 2026 def testBinaryAndTextFormat(self): 2027 test_dir = self._get_test_dir("binary_and_text") 2028 filename = os.path.join(test_dir, "metafile") 2029 # train.Saver is V1 only API. 2030 with ops_lib.Graph().as_default(), self.session(): 2031 # Creates a graph. 2032 variables.VariableV1(10.0, name="v0") 2033 # Exports the graph as binary format. 2034 saver_module.export_meta_graph(filename, as_text=False) 2035 with ops_lib.Graph().as_default(), self.session(): 2036 # Imports the binary format graph. 2037 saver = saver_module.import_meta_graph(filename) 2038 self.assertIsNotNone(saver) 2039 # Exports the graph as text format. 2040 saver.export_meta_graph(filename, as_text=True) 2041 with ops_lib.Graph().as_default(), self.session(): 2042 # Imports the text format graph. 2043 saver_module.import_meta_graph(filename) 2044 # Writes wrong contents to the file. 2045 graph_io.write_graph(saver.as_saver_def(), 2046 os.path.dirname(filename), 2047 os.path.basename(filename)) 2048 with ops_lib.Graph().as_default(), self.session(): 2049 # Import should fail. 2050 with self.assertRaisesWithPredicateMatch(IOError, 2051 lambda e: "Cannot parse file"): 2052 saver_module.import_meta_graph(filename) 2053 # Deletes the file 2054 gfile.Remove(filename) 2055 with self.assertRaisesWithPredicateMatch(IOError, 2056 lambda e: "does not exist"): 2057 saver_module.import_meta_graph(filename) 2058 2059 @test_util.run_v1_only( 2060 "Exporting/importing meta graphs is only supported in V1.") 2061 def testSliceVariable(self): 2062 test_dir = self._get_test_dir("slice_saver") 2063 filename = os.path.join(test_dir, "metafile") 2064 with self.cached_session(): 2065 v1 = variables.VariableV1([20.0], name="v1") 2066 v2 = variables.VariableV1([20.0], name="v2") 2067 v2._set_save_slice_info( 2068 variables.Variable.SaveSliceInfo("v1", [1], [0], [1])) 2069 2070 # The names are different and will work. 2071 slice_saver = saver_module.Saver({"first": v1, "second": v2}) 2072 self.evaluate(variables.global_variables_initializer()) 2073 # Exports to meta_graph 2074 meta_graph_def = slice_saver.export_meta_graph(filename) 2075 2076 with ops_lib.Graph().as_default(): 2077 # Restores from MetaGraphDef. 2078 new_saver = saver_module.import_meta_graph(filename) 2079 self.assertIsNotNone(new_saver) 2080 # Generates a new MetaGraphDef. 2081 new_meta_graph_def = new_saver.export_meta_graph() 2082 # It should be the same as the original. 2083 test_util.assert_meta_graph_protos_equal(self, meta_graph_def, 2084 new_meta_graph_def) 2085 2086 def _testGraphExtensionSave(self, test_dir): 2087 filename = os.path.join(test_dir, "metafile") 2088 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 2089 # Creates an inference graph. 2090 # Hidden 1 2091 images = constant_op.constant(1.2, dtypes.float32, shape=[100, 28]) 2092 with ops_lib.name_scope("hidden1"): 2093 weights = variables.VariableV1( 2094 random_ops.truncated_normal( 2095 [28, 128], stddev=1.0 / math.sqrt(float(28))), 2096 name="weights") 2097 # The use of control_flow_ops.cond here is purely for adding test coverage 2098 # the save and restore of control flow context (which doesn't make any 2099 # sense here from a machine learning perspective). The typical biases is 2100 # a simple Variable without the conditions. 2101 biases = variables.VariableV1( 2102 control_flow_ops.cond( 2103 math_ops.less(random.random(), 0.5), 2104 lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])), 2105 name="biases") 2106 hidden1 = nn_ops.relu(math_ops.matmul(images, weights) + biases) 2107 # Hidden 2 2108 with ops_lib.name_scope("hidden2"): 2109 weights = variables.VariableV1( 2110 random_ops.truncated_normal( 2111 [128, 32], stddev=1.0 / math.sqrt(float(128))), 2112 name="weights") 2113 2114 # The use of control_flow_ops.while_loop here is purely for adding test 2115 # coverage the save and restore of control flow context (which doesn't 2116 # make any sense here from a machine learning perspective). The typical 2117 # biases is a simple Variable without the conditions. 2118 def loop_cond(it, _): 2119 return it < 2 2120 2121 def loop_body(it, biases): 2122 biases += constant_op.constant(0.1, shape=[32]) 2123 return it + 1, biases 2124 2125 _, biases = control_flow_ops.while_loop( 2126 loop_cond, loop_body, 2127 [constant_op.constant(0), 2128 variables.VariableV1(array_ops.zeros([32]))]) 2129 hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases) 2130 # Linear 2131 with ops_lib.name_scope("softmax_linear"): 2132 weights = variables.VariableV1( 2133 random_ops.truncated_normal( 2134 [32, 10], stddev=1.0 / math.sqrt(float(32))), 2135 name="weights") 2136 biases = variables.VariableV1(array_ops.zeros([10]), name="biases") 2137 logits = math_ops.matmul(hidden2, weights) + biases 2138 ops_lib.add_to_collection("logits", logits) 2139 init_all_op = variables.global_variables_initializer() 2140 2141 with self.cached_session() as sess: 2142 # Initializes all the variables. 2143 self.evaluate(init_all_op) 2144 # Runs to logit. 2145 self.evaluate(logits) 2146 # Creates a saver. 2147 saver0 = saver_module.Saver() 2148 saver0.save(sess, saver0_ckpt) 2149 # Generates MetaGraphDef. 2150 saver0.export_meta_graph(filename) 2151 2152 def _testGraphExtensionRestore(self, test_dir): 2153 filename = os.path.join(test_dir, "metafile") 2154 train_filename = os.path.join(test_dir, "train_metafile") 2155 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 2156 with self.session(graph=ops_lib.Graph()) as sess: 2157 # Restores from MetaGraphDef. 2158 new_saver = saver_module.import_meta_graph(filename) 2159 # Generates a new MetaGraphDef. 2160 new_saver.export_meta_graph() 2161 # Restores from checkpoint. 2162 new_saver.restore(sess, saver0_ckpt) 2163 # Adds loss and train. 2164 labels = constant_op.constant(0, dtypes.int32, shape=[100], name="labels") 2165 batch_size = array_ops.size(labels) 2166 labels = array_ops.expand_dims(labels, 1) 2167 indices = array_ops.expand_dims(math_ops.range(0, batch_size), 1) 2168 concated = array_ops.concat([indices, labels], 1) 2169 onehot_labels = sparse_ops.sparse_to_dense( 2170 concated, array_ops.stack([batch_size, 10]), 1.0, 0.0) 2171 logits = ops_lib.get_collection("logits")[0] 2172 cross_entropy = nn_ops.softmax_cross_entropy_with_logits( 2173 labels=onehot_labels, logits=logits, name="xentropy") 2174 loss = math_ops.reduce_mean(cross_entropy, name="xentropy_mean") 2175 2176 summary.scalar("loss", loss) 2177 # Creates the gradient descent optimizer with the given learning rate. 2178 optimizer = gradient_descent.GradientDescentOptimizer(0.01) 2179 2180 # Runs train_op. 2181 train_op = optimizer.minimize(loss) 2182 ops_lib.add_to_collection("train_op", train_op) 2183 2184 # Runs train_op. 2185 self.evaluate(train_op) 2186 2187 # Generates MetaGraphDef. 2188 saver_module.export_meta_graph(train_filename) 2189 2190 def _testRestoreFromTrainGraphWithControlContext(self, test_dir): 2191 train_filename = os.path.join(test_dir, "train_metafile") 2192 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 2193 with self.session(graph=ops_lib.Graph()) as sess: 2194 # Restores from MetaGraphDef. 2195 new_saver = saver_module.import_meta_graph(train_filename) 2196 # Restores from checkpoint. 2197 new_saver.restore(sess, saver0_ckpt) 2198 train_op = ops_lib.get_collection("train_op")[0] 2199 self.evaluate(train_op) 2200 2201 def testGraphExtension(self): 2202 test_dir = self._get_test_dir("graph_extension") 2203 # train.Saver and train.import_meta_graph are V1 only APIs. 2204 with ops_lib.Graph().as_default(): 2205 self._testGraphExtensionSave(test_dir) 2206 self._testGraphExtensionRestore(test_dir) 2207 self._testRestoreFromTrainGraphWithControlContext(test_dir) 2208 2209 def _testGradientSerDes(self, graph_fn): 2210 """Tests that gradients can be computed after exporting and importing. 2211 2212 Builds a graph, exports it, and verifies that it can be imported and the 2213 gradient can be built and run correctly. 2214 2215 Args: 2216 graph_fn: takes a single float Tensor argument as input, outputs a single 2217 Tensor 2218 """ 2219 test_dir = self._get_test_dir("nested_control_flow") 2220 filename = os.path.join(test_dir, "metafile") 2221 saver_ckpt = os.path.join(test_dir, "saver.ckpt") 2222 2223 # Create while loop using `outer_body_fn`. 2224 with ops_lib.Graph().as_default(): 2225 var = variables.VariableV1(0.0) 2226 var_name = var.name 2227 output = graph_fn(var) 2228 output_name = output.name 2229 init_op = variables.global_variables_initializer() 2230 2231 # Generate a MetaGraphDef containing the while loop. 2232 with session.Session() as sess: 2233 self.evaluate(init_op) 2234 self.evaluate(output) 2235 saver = saver_module.Saver() 2236 saver.save(sess, saver_ckpt) 2237 saver.export_meta_graph(filename) 2238 2239 # Build and run the gradients of the while loop. We use this below to 2240 # verify that the gradients are correct with an imported MetaGraphDef. 2241 grad = gradients_impl.gradients([output], [var]) 2242 # Turn off constant folding to avoid breaking testNestedControlFlowSerDes. 2243 # It appears that a missing control dependency in the gradient graph 2244 # causes the fetch node to not be triggered. 2245 no_constfold_config = config_pb2.ConfigProto() 2246 no_constfold_config.graph_options.rewrite_options.constant_folding = ( 2247 rewriter_config_pb2.RewriterConfig.OFF) 2248 with session.Session(config=no_constfold_config) as sess: 2249 self.evaluate(init_op) 2250 expected_grad_value = self.evaluate(grad) 2251 2252 # Restore the MetaGraphDef into a new Graph. 2253 with ops_lib.Graph().as_default(): 2254 with session.Session() as sess: 2255 saver = saver_module.import_meta_graph(filename) 2256 saver.restore(sess, saver_ckpt) 2257 2258 # Make sure we can still build gradients and get the same result. 2259 var = ops_lib.get_default_graph().get_tensor_by_name(var_name) 2260 output = ops_lib.get_default_graph().get_tensor_by_name(output_name) 2261 grad = gradients_impl.gradients([output], [var]) 2262 2263 init_op = variables.global_variables_initializer() 2264 2265 with session.Session(config=no_constfold_config) as sess: 2266 self.evaluate(init_op) 2267 actual_grad_value = self.evaluate(grad) 2268 self.assertEqual(expected_grad_value, actual_grad_value) 2269 2270 def _testWhileLoopAndGradientSerDes(self, outer_body_fn): 2271 # Build a while loop with `outer_body_fn`, export it, and verify that it can 2272 # be imported and the gradient can be built and run correctly. 2273 # pylint: disable=g-long-lambda 2274 return self._testGradientSerDes( 2275 lambda x: control_flow_ops.while_loop( 2276 lambda i, y: i < 5, outer_body_fn, [0, x])[1]) 2277 # pylint: enable=g-long-lambda 2278 2279 def testNestedWhileLoopsSerDes(self): 2280 # Test two simple nested while loops. 2281 def body(i, x): 2282 _, r = control_flow_ops.while_loop(lambda j, y: j < 3, 2283 lambda j, y: (j + 1, y + x), 2284 [0, 0.0]) 2285 return i + 1, x + r 2286 self._testWhileLoopAndGradientSerDes(body) 2287 2288 def testNestedControlFlowSerDes(self): 2289 # Test while loop in a cond in a while loop. 2290 # pylint: disable=g-long-lambda 2291 def body(i, x): 2292 cond_result = control_flow_ops.cond( 2293 i > 0, 2294 lambda: control_flow_ops.while_loop( 2295 lambda j, y: j < 3, 2296 lambda j, y: (j + 1, y + x), 2297 [0, 0.0])[1], 2298 lambda: x) 2299 return i + 1, cond_result 2300 # pylint: enable=g-long-lambda 2301 self._testWhileLoopAndGradientSerDes(body) 2302 2303 def testNestedCondsSerDes(self): 2304 # Test conds in a cond. 2305 # pylint: disable=g-long-lambda 2306 self._testGradientSerDes(lambda x: control_flow_ops.cond( 2307 x > 0, 2308 lambda: control_flow_ops.cond(x > 3, 2309 lambda: array_ops.identity(x), 2310 lambda: math_ops.multiply(x, 2.0)), 2311 lambda: control_flow_ops.cond(x < -3, 2312 lambda: constant_op.constant(1.0), 2313 lambda: math_ops.multiply(x, -1.0)))) 2314 # pylint: enable=g-long-lambda 2315 2316 @test_util.run_v1_only("This exercises Tensor.op which is meaningless in V2.") 2317 def testStrippedOpListDef(self): 2318 with self.cached_session(): 2319 # Creates a graph. 2320 v0 = variables.VariableV1(0.0) 2321 var = variables.VariableV1(10.0) 2322 math_ops.add(v0, var) 2323 2324 @function.Defun(dtypes.float32) 2325 def minus_one(x): 2326 return x - 1 2327 2328 minus_one(array_ops.identity(v0)) 2329 save = saver_module.Saver({"v0": v0}) 2330 variables.global_variables_initializer() 2331 2332 # Generates MetaGraphDef. 2333 meta_graph_def = save.export_meta_graph() 2334 ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op] 2335 if save._write_version is saver_pb2.SaverDef.V1: 2336 self.assertEqual(ops, [ 2337 "Add", "Assign", "Const", "Identity", "NoOp", 2338 "PlaceholderWithDefault", "RestoreV2", "SaveSlices", "Sub", 2339 "VariableV2" 2340 ]) 2341 else: 2342 self.assertEqual(ops, [ 2343 "Add", "Assign", "Const", "Identity", "NoOp", 2344 "PlaceholderWithDefault", "RestoreV2", "SaveV2", "Sub", "VariableV2" 2345 ]) 2346 2347 # Test calling stripped_op_list_for_graph directly 2348 op_list = meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def) 2349 self.assertEqual(ops, [o.name for o in op_list.op]) 2350 for o in op_list.op: 2351 self.assertEqual(o.summary, "") 2352 self.assertEqual(o.description, "") 2353 2354 def testStripDefaultValuedAttrs(self): 2355 """Verifies that default valued attrs are stripped, unless disabled.""" 2356 2357 # With strip_default_attrs enabled, attributes "T" (float32) and "Tout" 2358 # (complex64) in the "Complex" op must be removed. 2359 # train.Saver and train.export_meta_graph are V1 only APIs. 2360 with ops_lib.Graph().as_default(), self.cached_session(): 2361 real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real") 2362 imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag") 2363 math_ops.complex(real_num, imag_num, name="complex") 2364 2365 save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num}) 2366 variables.global_variables_initializer() 2367 2368 meta_graph_def = save.export_meta_graph(strip_default_attrs=True) 2369 node_def = test_util.get_node_def_from_graph("complex", 2370 meta_graph_def.graph_def) 2371 self.assertNotIn("T", node_def.attr) 2372 self.assertNotIn("Tout", node_def.attr) 2373 2374 # With strip_default_attrs disabled, attributes "T" (float32) and "Tout" 2375 # (complex64) in the "Complex" op must *not* be removed, even if they map 2376 # to their defaults. 2377 with ops_lib.Graph().as_default(), self.session(): 2378 real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real") 2379 imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag") 2380 math_ops.complex(real_num, imag_num, name="complex") 2381 2382 save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num}) 2383 variables.global_variables_initializer() 2384 2385 meta_graph_def = save.export_meta_graph(strip_default_attrs=False) 2386 node_def = test_util.get_node_def_from_graph("complex", 2387 meta_graph_def.graph_def) 2388 self.assertIn("T", node_def.attr) 2389 self.assertIn("Tout", node_def.attr) 2390 2391 def testImportIntoNamescope(self): 2392 # Test that we can import a meta graph into a namescope. 2393 test_dir = self._get_test_dir("import_into_namescope") 2394 filename = os.path.join(test_dir, "ckpt") 2395 # train.Saver is V1 only API. 2396 with ops_lib.Graph().as_default(): 2397 image = array_ops.placeholder(dtypes.float32, [None, 784], name="image") 2398 label = array_ops.placeholder(dtypes.float32, [None, 10], name="label") 2399 with session.Session() as sess: 2400 weights = variables.VariableV1( 2401 random_ops.random_uniform([784, 10]), name="weights") 2402 bias = variables.VariableV1(array_ops.zeros([10]), name="bias") 2403 logit = nn_ops.relu( 2404 math_ops.matmul(image, weights) + bias, name="logits") 2405 nn_ops.softmax(logit, name="prediction") 2406 cost = nn_ops.softmax_cross_entropy_with_logits( 2407 labels=label, logits=logit, name="cost") 2408 adam.AdamOptimizer().minimize(cost, name="optimize") 2409 saver = saver_module.Saver() 2410 self.evaluate(variables.global_variables_initializer()) 2411 saver.save(sess, filename) 2412 2413 graph = ops_lib.Graph() 2414 with session.Session(graph=graph) as sess: 2415 new_saver = saver_module.import_meta_graph( 2416 filename + ".meta", graph=graph, import_scope="new_model") 2417 new_saver.restore(sess, filename) 2418 sess.run(["new_model/optimize"], { 2419 "new_model/image:0": np.random.random([1, 784]), 2420 "new_model/label:0": np.random.randint( 2421 10, size=[1, 10]) 2422 }) 2423 2424 def testImportIntoNamescopeWithoutVariables(self): 2425 # Save a simple graph that contains no variables into a checkpoint. 2426 test_dir = self._get_test_dir("no_vars_graph") 2427 filename = os.path.join(test_dir, "ckpt") 2428 graph_1 = ops_lib.Graph() 2429 with session.Session(graph=graph_1) as sess: 2430 constant_op.constant([1, 2, 3], name="x") 2431 constant_op.constant([1, 2, 3], name="y") 2432 saver = saver_module.Saver(allow_empty=True) 2433 saver.save(sess, filename) 2434 2435 # Create a fresh graph. 2436 graph_2 = ops_lib.Graph() 2437 with session.Session(graph=graph_2) as sess: 2438 # Restore the above checkpoint under scope "subgraph_1". 2439 new_saver_1 = saver_module.import_meta_graph( 2440 filename + ".meta", graph=graph_2, import_scope="subgraph_1") 2441 # There are no variables to restore, so import_meta_graph should not 2442 # return a Saver. 2443 self.assertIsNone(new_saver_1) 2444 2445 # Create a variable in graph_2 under scope "my_scope". 2446 variables.VariableV1(array_ops.zeros([10]), name="my_scope/my_var") 2447 self.evaluate(variables.global_variables_initializer()) 2448 # Restore the checkpoint into a different scope "subgraph_2". 2449 new_saver_2 = saver_module.import_meta_graph( 2450 filename + ".meta", graph=graph_2, import_scope="subgraph_2") 2451 # Because the variable does not live in scope "subgraph_2", 2452 # import_meta_graph should not attempt to restore the variable. So, 2453 # import_meta_graph still won't return a Saver instance. 2454 self.assertIsNone(new_saver_2) 2455 2456 # However, if we restore the checkpoint under scope "my_scope", 2457 # import_meta_graph will detect the variable and return a Saver for 2458 # restoring it. This should happen even when the variable does not 2459 # originate from graph_1. 2460 new_saver_3 = saver_module.import_meta_graph( 2461 filename + ".meta", graph=graph_2, import_scope="my_scope") 2462 self.assertIsInstance(new_saver_3, saver_module.Saver) 2463 2464 def testImportIntoImplicitNamescope(self): 2465 # Test that we can import a meta graph into an implicit namescope. 2466 test_dir = self._get_test_dir("import_into_namescope") 2467 filename = os.path.join(test_dir, "ckpt") 2468 # train.Saver is V1 only API. 2469 with ops_lib.Graph().as_default(): 2470 image = array_ops.placeholder(dtypes.float32, [None, 784], name="image") 2471 label = array_ops.placeholder(dtypes.float32, [None, 10], name="label") 2472 with session.Session() as sess: 2473 weights = variables.VariableV1( 2474 random_ops.random_uniform([784, 10]), name="weights") 2475 bias = variables.VariableV1(array_ops.zeros([10]), name="bias") 2476 logit = nn_ops.relu( 2477 math_ops.matmul(image, weights) + bias, name="logits") 2478 nn_ops.softmax(logit, name="prediction") 2479 cost = nn_ops.softmax_cross_entropy_with_logits( 2480 labels=label, logits=logit, name="cost") 2481 adam.AdamOptimizer().minimize(cost, name="optimize") 2482 saver = saver_module.Saver() 2483 self.evaluate(variables.global_variables_initializer()) 2484 saver.save(sess, filename) 2485 2486 graph = ops_lib.Graph() 2487 with session.Session(graph=graph) as sess: 2488 with ops_lib.name_scope("new_model"): 2489 new_saver = saver_module.import_meta_graph( 2490 filename + ".meta", graph=graph) 2491 2492 new_saver.restore(sess, filename) 2493 sess.run(["new_model/optimize"], { 2494 "new_model/image:0": np.random.random([1, 784]), 2495 "new_model/label:0": np.random.randint( 2496 10, size=[1, 10]) 2497 }) 2498 2499 def testClearDevicesOnImport(self): 2500 # Test that we import a graph without its devices and run successfully. 2501 with ops_lib.Graph().as_default(): 2502 with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"): 2503 image = array_ops.placeholder(dtypes.float32, [None, 784], name="image") 2504 label = array_ops.placeholder(dtypes.float32, [None, 10], name="label") 2505 weights = variables.VariableV1( 2506 random_ops.random_uniform([784, 10]), name="weights") 2507 bias = variables.VariableV1(array_ops.zeros([10]), name="bias") 2508 logit = nn_ops.relu(math_ops.matmul(image, weights) + bias) 2509 nn_ops.softmax(logit, name="prediction") 2510 cost = nn_ops.softmax_cross_entropy_with_logits(labels=label, 2511 logits=logit) 2512 adam.AdamOptimizer().minimize(cost, name="optimize") 2513 meta_graph_def = saver_module.export_meta_graph() 2514 2515 with session.Session(graph=ops_lib.Graph()) as sess: 2516 saver_module.import_meta_graph( 2517 meta_graph_def, clear_devices=False, import_scope="new_model") 2518 # Device refers to GPU, which is not available here. 2519 with self.assertRaises(errors_impl.InvalidArgumentError): 2520 self.evaluate(variables.global_variables_initializer()) 2521 2522 with session.Session(graph=ops_lib.Graph()) as sess: 2523 saver_module.import_meta_graph( 2524 meta_graph_def, clear_devices=True, import_scope="new_model") 2525 self.evaluate(variables.global_variables_initializer()) 2526 sess.run(["new_model/optimize"], { 2527 "new_model/image:0": np.random.random([1, 784]), 2528 "new_model/label:0": np.random.randint( 2529 10, size=[1, 10]) 2530 }) 2531 2532 def testClearDevicesOnExport(self): 2533 # Test that we export a graph without its devices and run successfully. 2534 with ops_lib.Graph().as_default(): 2535 with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"): 2536 image = array_ops.placeholder(dtypes.float32, [None, 784], name="image") 2537 label = array_ops.placeholder(dtypes.float32, [None, 10], name="label") 2538 weights = variables.VariableV1( 2539 random_ops.random_uniform([784, 10]), name="weights") 2540 bias = variables.VariableV1(array_ops.zeros([10]), name="bias") 2541 logit = nn_ops.relu(math_ops.matmul(image, weights) + bias) 2542 nn_ops.softmax(logit, name="prediction") 2543 cost = nn_ops.softmax_cross_entropy_with_logits(labels=label, 2544 logits=logit) 2545 adam.AdamOptimizer().minimize(cost, name="optimize") 2546 meta_graph_def = saver_module.export_meta_graph(clear_devices=True) 2547 graph_io.write_graph(meta_graph_def, self.get_temp_dir(), 2548 "meta_graph.pbtxt") 2549 2550 with session.Session(graph=ops_lib.Graph()) as sess: 2551 saver_module.import_meta_graph(meta_graph_def, import_scope="new_model") 2552 self.evaluate(variables.global_variables_initializer()) 2553 sess.run(["new_model/optimize"], { 2554 "new_model/image:0": np.random.random([1, 784]), 2555 "new_model/label:0": np.random.randint( 2556 10, size=[1, 10]) 2557 }) 2558 2559 def testPreserveDatasetAndFunctions(self): 2560 with ops_lib.Graph().as_default() as g: 2561 dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x) 2562 iterator = dataset_ops.make_one_shot_iterator(dataset) 2563 next_element = iterator.get_next() 2564 _ = array_ops.identity(next_element, name="output") 2565 2566 # Generate three MetaGraphDef protos using different code paths. 2567 meta_graph_def_simple = saver_module.export_meta_graph() 2568 meta_graph_def_devices_cleared = saver_module.export_meta_graph( 2569 clear_devices=True) 2570 meta_graph_def_from_graph_def = saver_module.export_meta_graph( 2571 clear_devices=True, graph_def=g.as_graph_def()) 2572 2573 for meta_graph_def in [meta_graph_def_simple, 2574 meta_graph_def_devices_cleared, 2575 meta_graph_def_from_graph_def]: 2576 with session.Session(graph=ops_lib.Graph()) as sess: 2577 saver_module.import_meta_graph(meta_graph_def, import_scope="new_model") 2578 self.evaluate(variables.global_variables_initializer()) 2579 for i in range(10): 2580 self.assertEqual(i * i, sess.run("new_model/output:0")) 2581 with self.assertRaises(errors.OutOfRangeError): 2582 sess.run("new_model/output:0") 2583 2584 2585class CheckpointReaderTest(test.TestCase): 2586 2587 _WRITE_VERSION = saver_pb2.SaverDef.V1 2588 2589 def testDebugString(self): 2590 # Builds a graph. 2591 v0 = variables.VariableV1( 2592 [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") 2593 v1 = variables.VariableV1( 2594 [[[1], [2]], [[3], [4]], [[5], [6]]], dtype=dtypes.float32, name="v1") 2595 init_all_op = variables.global_variables_initializer() 2596 save = saver_module.Saver( 2597 { 2598 "v0": v0, 2599 "v1": v1 2600 }, write_version=self._WRITE_VERSION) 2601 save_path = os.path.join(self.get_temp_dir(), 2602 "ckpt_for_debug_string" + str(self._WRITE_VERSION)) 2603 with self.cached_session() as sess: 2604 self.evaluate(init_all_op) 2605 # Saves a checkpoint. 2606 save.save(sess, save_path) 2607 2608 # Creates a reader. 2609 reader = py_checkpoint_reader.NewCheckpointReader(save_path) 2610 # Verifies that the tensors exist. 2611 self.assertTrue(reader.has_tensor("v0")) 2612 self.assertTrue(reader.has_tensor("v1")) 2613 debug_string = reader.debug_string() 2614 # Verifies that debug string contains the right strings. 2615 self.assertTrue(compat.as_bytes("v0 (DT_FLOAT) [2,3]") in debug_string) 2616 self.assertTrue(compat.as_bytes("v1 (DT_FLOAT) [3,2,1]") in debug_string) 2617 # Verifies get_variable_to_shape_map() returns the correct information. 2618 var_map = reader.get_variable_to_shape_map() 2619 self.assertEqual([2, 3], var_map["v0"]) 2620 self.assertEqual([3, 2, 1], var_map["v1"]) 2621 # Verifies get_tensor() returns the tensor value. 2622 v0_tensor = reader.get_tensor("v0") 2623 v1_tensor = reader.get_tensor("v1") 2624 self.assertAllEqual(v0, v0_tensor) 2625 self.assertAllEqual(v1, v1_tensor) 2626 # Verifies get_tensor() fails for non-existent tensors. 2627 with self.assertRaisesRegex(errors.NotFoundError, 2628 "v3 not found in checkpoint"): 2629 reader.get_tensor("v3") 2630 2631 def testNonexistentPath(self): 2632 with self.assertRaisesRegex(errors.NotFoundError, 2633 "Unsuccessful TensorSliceReader"): 2634 py_checkpoint_reader.NewCheckpointReader("non-existent") 2635 2636 2637class CheckpointReaderForV2Test(CheckpointReaderTest): 2638 _WRITE_VERSION = saver_pb2.SaverDef.V2 2639 2640 2641class WriteGraphTest(test.TestCase): 2642 2643 def _get_test_dir(self, dirname): 2644 test_dir = os.path.join(self.get_temp_dir(), dirname) 2645 gfile.MakeDirs(test_dir) 2646 return test_dir 2647 2648 def testWriteGraph(self): 2649 test_dir = self._get_test_dir("write_graph_dir") 2650 variables.VariableV1( 2651 [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") 2652 path = graph_io.write_graph(ops_lib.get_default_graph(), 2653 os.path.join(test_dir, "l1"), "graph.pbtxt") 2654 truth = os.path.join(test_dir, "l1", "graph.pbtxt") 2655 self.assertEqual(path, truth) 2656 self.assertTrue(os.path.exists(path)) 2657 2658 def testRecursiveCreate(self): 2659 test_dir = self._get_test_dir("deep_dir") 2660 variables.VariableV1( 2661 [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") 2662 path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(), 2663 os.path.join(test_dir, "l1", "l2", "l3"), 2664 "graph.pbtxt") 2665 truth = os.path.join(test_dir, "l1", "l2", "l3", "graph.pbtxt") 2666 self.assertEqual(path, truth) 2667 self.assertTrue(os.path.exists(path)) 2668 2669 2670class ScopedGraphTest(test.TestCase): 2671 2672 def _get_test_dir(self, dirname): 2673 test_dir = os.path.join(self.get_temp_dir(), dirname) 2674 gfile.MakeDirs(test_dir) 2675 return test_dir 2676 2677 def _testScopedSave(self, test_dir, exported_filename, ckpt_filename): 2678 graph = ops_lib.Graph() 2679 with graph.as_default(): 2680 # Creates an inference graph. 2681 # Hidden 1 2682 images = constant_op.constant( 2683 1.2, dtypes.float32, shape=[100, 28], name="images") 2684 with ops_lib.name_scope("hidden1"): 2685 weights1 = variables.VariableV1( 2686 random_ops.truncated_normal( 2687 [28, 128], stddev=1.0 / math.sqrt(float(28))), 2688 name="weights") 2689 # The use of control_flow_ops.cond here is purely for adding test 2690 # coverage the save and restore of control flow context (which doesn't 2691 # make any sense here from a machine learning perspective). The typical 2692 # biases is a simple Variable without the conditions. 2693 biases1 = variables.VariableV1( 2694 control_flow_ops.cond( 2695 math_ops.less(random.random(), 0.5), 2696 lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])), 2697 name="biases") 2698 hidden1 = nn_ops.relu(math_ops.matmul(images, weights1) + biases1) 2699 2700 # Hidden 2 2701 with ops_lib.name_scope("hidden2"): 2702 weights2 = variables.VariableV1( 2703 random_ops.truncated_normal( 2704 [128, 32], stddev=1.0 / math.sqrt(float(128))), 2705 name="weights") 2706 2707 # The use of control_flow_ops.while_loop here is purely for adding test 2708 # coverage the save and restore of control flow context (which doesn't 2709 # make any sense here from a machine learning perspective). The typical 2710 # biases is a simple Variable without the conditions. 2711 def loop_cond(it, _): 2712 return it < 2 2713 2714 def loop_body(it, biases2): 2715 biases2 += constant_op.constant(0.1, shape=[32]) 2716 return it + 1, biases2 2717 2718 _, biases2 = control_flow_ops.while_loop(loop_cond, loop_body, [ 2719 constant_op.constant(0), variables.VariableV1(array_ops.zeros([32])) 2720 ]) 2721 hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2) 2722 # Linear 2723 with ops_lib.name_scope("softmax_linear"): 2724 weights3 = variables.VariableV1( 2725 random_ops.truncated_normal( 2726 [32, 10], stddev=1.0 / math.sqrt(float(32))), 2727 name="weights") 2728 biases3 = variables.VariableV1(array_ops.zeros([10]), name="biases") 2729 logits = math_ops.matmul(hidden2, weights3) + biases3 2730 ops_lib.add_to_collection("logits", logits) 2731 2732 # Adds user_defined proto in three formats: string, bytes and Any. 2733 # Any proto should just pass through. 2734 queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue") 2735 ops_lib.add_to_collection("user_defined_string_collection", 2736 str(queue_runner)) 2737 ops_lib.add_to_collection("user_defined_bytes_collection", 2738 queue_runner.SerializeToString()) 2739 any_buf = Any() 2740 any_buf.Pack(queue_runner) 2741 ops_lib.add_to_collection("user_defined_any_collection", any_buf) 2742 2743 _, var_list = meta_graph.export_scoped_meta_graph( 2744 filename=os.path.join(test_dir, exported_filename), 2745 graph=ops_lib.get_default_graph(), 2746 export_scope="hidden1") 2747 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys())) 2748 2749 with graph.as_default(), self.session() as sess: 2750 self.evaluate(variables.global_variables_initializer()) 2751 saver = saver_module.Saver(var_list=var_list, max_to_keep=1) 2752 saver.save(sess, os.path.join(test_dir, ckpt_filename), write_state=False) 2753 2754 def _testScopedRestore(self, test_dir, exported_filename, 2755 new_exported_filename, ckpt_filename): 2756 graph = ops_lib.Graph() 2757 # Create all the missing inputs. 2758 with graph.as_default(): 2759 new_image = constant_op.constant( 2760 1.2, dtypes.float32, shape=[100, 28], name="images") 2761 var_list = meta_graph.import_scoped_meta_graph( 2762 os.path.join(test_dir, exported_filename), 2763 graph=graph, 2764 input_map={"$unbound_inputs_images": new_image}, 2765 import_scope="new_hidden1") 2766 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys())) 2767 hidden1 = graph.as_graph_element("new_hidden1/Relu:0") 2768 weights1 = graph.as_graph_element("new_hidden1/weights:0") 2769 biases1 = graph.as_graph_element("new_hidden1/biases:0") 2770 2771 with graph.as_default(): 2772 # Hidden 2 2773 with ops_lib.name_scope("hidden2"): 2774 weights = variables.VariableV1( 2775 random_ops.truncated_normal( 2776 [128, 32], stddev=1.0 / math.sqrt(float(128))), 2777 name="weights") 2778 2779 # The use of control_flow_ops.while_loop here is purely for adding test 2780 # coverage the save and restore of control flow context (which doesn't 2781 # make any sense here from a machine learning perspective). The typical 2782 # biases is a simple Variable without the conditions. 2783 def loop_cond(it, _): 2784 return it < 2 2785 2786 def loop_body(it, biases): 2787 biases += constant_op.constant(0.1, shape=[32]) 2788 return it + 1, biases 2789 2790 _, biases = control_flow_ops.while_loop(loop_cond, loop_body, [ 2791 constant_op.constant(0), variables.VariableV1(array_ops.zeros([32])) 2792 ]) 2793 hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases) 2794 # Linear 2795 with ops_lib.name_scope("softmax_linear"): 2796 weights = variables.VariableV1( 2797 random_ops.truncated_normal( 2798 [32, 10], stddev=1.0 / math.sqrt(float(32))), 2799 name="weights") 2800 biases = variables.VariableV1(array_ops.zeros([10]), name="biases") 2801 logits = math_ops.matmul(hidden2, weights) + biases 2802 ops_lib.add_to_collection("logits", logits) 2803 2804 # The rest of the variables. 2805 rest_variables = list( 2806 set(variables.global_variables()) - set(var_list.keys())) 2807 init_rest_op = variables.variables_initializer(rest_variables) 2808 2809 with graph.as_default(), self.session() as sess: 2810 saver = saver_module.Saver(var_list=var_list, max_to_keep=1) 2811 saver.restore(sess, os.path.join(test_dir, ckpt_filename)) 2812 # Verify that we have restored weights1 and biases1. 2813 self.evaluate([weights1, biases1]) 2814 # Initialize the rest of the variables and run logits. 2815 self.evaluate(init_rest_op) 2816 self.evaluate(logits) 2817 2818 # Verifies that we can save the subgraph under "hidden1" and restore it 2819 # into "new_hidden1" in the new graph. 2820 def testScopedSaveAndRestore(self): 2821 test_dir = self._get_test_dir("scoped_export_import") 2822 ckpt_filename = "ckpt" 2823 self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename) 2824 self._testScopedRestore(test_dir, "exported_hidden1.pbtxt", 2825 "exported_new_hidden1.pbtxt", ckpt_filename) 2826 2827 # Verifies that we can copy the subgraph under "hidden1" and copy it 2828 # to different name scope in the same graph or different graph. 2829 def testCopyScopedGraph(self): 2830 test_dir = self._get_test_dir("scoped_copy") 2831 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 2832 graph1 = ops_lib.Graph() 2833 with graph1.as_default(): 2834 with ops_lib.name_scope("hidden1"): 2835 images = constant_op.constant( 2836 1.0, dtypes.float32, shape=[3, 2], name="images") 2837 weights1 = variables.VariableV1( 2838 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights") 2839 biases1 = variables.VariableV1([0.1] * 3, name="biases") 2840 nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu") 2841 2842 # Run the graph and save scoped checkpoint. 2843 with graph1.as_default(), self.session(graph=graph1) as sess: 2844 self.evaluate(variables.global_variables_initializer()) 2845 _, var_list_1 = meta_graph.export_scoped_meta_graph( 2846 export_scope="hidden1") 2847 saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1) 2848 saver.save(sess, saver0_ckpt, write_state=False) 2849 2850 expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3)) 2851 2852 # Verifies copy to the same graph with the same name fails. 2853 with graph1.as_default(): 2854 with self.assertRaisesWithPredicateMatch( 2855 ValueError, lambda e: "need to be different" in str(e)): 2856 meta_graph.copy_scoped_meta_graph( 2857 from_scope="hidden1", to_scope="hidden1") 2858 2859 # Verifies copy to the same graph. 2860 with graph1.as_default(): 2861 var_list_2 = meta_graph.copy_scoped_meta_graph( 2862 from_scope="hidden1", to_scope="hidden2") 2863 2864 with graph1.as_default(), self.session(graph=graph1) as sess: 2865 saver1 = saver_module.Saver(var_list=var_list_1, max_to_keep=1) 2866 saver1.restore(sess, saver0_ckpt) 2867 saver2 = saver_module.Saver(var_list=var_list_2, max_to_keep=1) 2868 saver2.restore(sess, saver0_ckpt) 2869 self.assertAllClose(expected, sess.run("hidden1/relu:0")) 2870 self.assertAllClose(expected, sess.run("hidden2/relu:0")) 2871 2872 # Verifies copy to different graph. 2873 graph2 = ops_lib.Graph() 2874 with graph2.as_default(): 2875 new_var_list_1 = meta_graph.copy_scoped_meta_graph( 2876 from_scope="hidden1", 2877 to_scope="new_hidden1", 2878 from_graph=graph1, 2879 to_graph=graph2) 2880 2881 with self.session() as sess: 2882 saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1) 2883 saver3.restore(sess, saver0_ckpt) 2884 self.assertAllClose(expected, sess.run("new_hidden1/relu:0")) 2885 2886 def testExportGraphDefWithScope(self): 2887 test_dir = self._get_test_dir("export_graph_def") 2888 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 2889 graph1 = ops_lib.Graph() 2890 with graph1.as_default(): 2891 with ops_lib.name_scope("hidden1"): 2892 images = constant_op.constant( 2893 1.0, dtypes.float32, shape=[3, 2], name="images") 2894 weights1 = variables.VariableV1( 2895 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights") 2896 biases1 = variables.VariableV1([0.1] * 3, name="biases") 2897 nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu") 2898 2899 # Run the graph and save scoped checkpoint. 2900 with self.session(graph=graph1) as sess: 2901 self.evaluate(variables.global_variables_initializer()) 2902 _, var_list_1 = meta_graph.export_scoped_meta_graph( 2903 graph_def=graph1.as_graph_def(), export_scope="hidden1") 2904 saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1) 2905 saver.save(sess, saver0_ckpt, write_state=False) 2906 2907 expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3)) 2908 2909 # Verifies that we can run successfully after restoring. 2910 graph2 = ops_lib.Graph() 2911 with graph2.as_default(): 2912 new_var_list_1 = meta_graph.copy_scoped_meta_graph( 2913 from_scope="hidden1", 2914 to_scope="new_hidden1", 2915 from_graph=graph1, 2916 to_graph=graph2) 2917 2918 with self.session(graph=graph2) as sess: 2919 saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1) 2920 saver3.restore(sess, saver0_ckpt) 2921 self.assertAllClose(expected, sess.run("new_hidden1/relu:0")) 2922 2923 def testSerializeSaverWithScope(self): 2924 test_dir = self._get_test_dir("export_graph_def") 2925 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt") 2926 saver2_ckpt = os.path.join(test_dir, "saver2.ckpt") 2927 graph = ops_lib.Graph() 2928 with graph.as_default(): 2929 with ops_lib.name_scope("hidden1"): 2930 variable1 = variables.VariableV1([1.0], name="variable1") 2931 saver1 = saver_module.Saver(var_list=[variable1]) 2932 graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver1) 2933 2934 with ops_lib.name_scope("hidden2"): 2935 variable2 = variables.VariableV1([2.0], name="variable2") 2936 saver2 = saver_module.Saver(var_list=[variable2], name="hidden2/") 2937 graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver2) 2938 2939 with self.session(graph=graph) as sess: 2940 self.evaluate(variables.global_variables_initializer()) 2941 saver1.save(sess, saver1_ckpt, write_state=False) 2942 saver2.save(sess, saver2_ckpt, write_state=False) 2943 2944 graph1 = ops_lib.Graph() 2945 with graph1.as_default(): 2946 var_dict1 = meta_graph.copy_scoped_meta_graph( 2947 from_scope="hidden1", 2948 to_scope="new_hidden1", 2949 from_graph=graph, 2950 to_graph=graph1) 2951 self.assertEqual(1, len(var_dict1)) 2952 2953 saver_list1 = graph1.get_collection(ops_lib.GraphKeys.SAVERS) 2954 self.assertEqual(1, len(saver_list1)) 2955 2956 with self.session(graph=graph1) as sess: 2957 saver_list1[0].restore(sess, saver1_ckpt) 2958 self.assertEqual(1.0, self.evaluate(var_dict1["variable1:0"])) 2959 2960 graph2 = ops_lib.Graph() 2961 with graph2.as_default(): 2962 var_dict2 = meta_graph.copy_scoped_meta_graph( 2963 from_scope="hidden2", 2964 to_scope="new_hidden2", 2965 from_graph=graph, 2966 to_graph=graph2) 2967 self.assertEqual(1, len(var_dict2)) 2968 2969 saver_list2 = graph2.get_collection(ops_lib.GraphKeys.SAVERS) 2970 self.assertEqual(1, len(saver_list2)) 2971 2972 with self.session(graph=graph2) as sess: 2973 saver_list2[0].restore(sess, saver2_ckpt) 2974 self.assertEqual(2.0, self.evaluate(var_dict2["variable2:0"])) 2975 2976 2977class _OwnsAVariableSimple(trackable_base.Trackable): 2978 """A Trackable object which can be saved using a tf.train.Saver.""" 2979 2980 def __init__(self): 2981 self.non_dep_variable = variable_scope.get_variable( 2982 name="non_dep_variable", initializer=6., use_resource=True) 2983 2984 def _gather_saveables_for_checkpoint(self): 2985 return {trackable_base.VARIABLE_VALUE_KEY: self.non_dep_variable} 2986 2987 # The Saver sorts by name before parsing, so we need a name property. 2988 @property 2989 def name(self): 2990 return self.non_dep_variable.name 2991 2992 2993class _MirroringSaveable( 2994 saver_module.BaseSaverBuilder.ResourceVariableSaveable): 2995 2996 def __init__(self, primary_variable, mirrored_variable, name): 2997 self._primary_variable = primary_variable 2998 self._mirrored_variable = mirrored_variable 2999 super(_MirroringSaveable, self).__init__( 3000 self._primary_variable, "", name) 3001 3002 def restore(self, restored_tensors, restored_shapes): 3003 """Restore the same value into both variables.""" 3004 tensor, = restored_tensors 3005 return control_flow_ops.group( 3006 self._primary_variable.assign(tensor), 3007 self._mirrored_variable.assign(tensor)) 3008 3009 3010class _OwnsMirroredVariables(trackable_base.Trackable): 3011 """A Trackable object which returns a more complex SaveableObject.""" 3012 3013 def __init__(self): 3014 self.non_dep_variable = variable_scope.get_variable( 3015 name="non_dep_variable", initializer=6., use_resource=True) 3016 self.mirrored = variable_scope.get_variable( 3017 name="mirrored", initializer=15., use_resource=True) 3018 3019 def _gather_saveables_for_checkpoint(self): 3020 def _saveable_factory(name=self.non_dep_variable.name): 3021 return _MirroringSaveable( 3022 primary_variable=self.non_dep_variable, 3023 mirrored_variable=self.mirrored, 3024 name=name) 3025 return {trackable_base.VARIABLE_VALUE_KEY: _saveable_factory} 3026 3027 # The Saver sorts by name before parsing, so we need a name property. 3028 @property 3029 def name(self): 3030 return self.non_dep_variable.name 3031 3032 3033class TrackableCompatibilityTests(test.TestCase): 3034 3035 # TODO(allenl): Track down python3 reference cycles in these tests. 3036 @test_util.run_in_graph_and_eager_modes 3037 def testNotSaveableButIsTrackable(self): 3038 v = _OwnsAVariableSimple() 3039 test_dir = self.get_temp_dir() 3040 prefix = os.path.join(test_dir, "ckpt") 3041 for saver in (saver_module.Saver(var_list=[v]), 3042 saver_module.Saver(var_list={"v": v})): 3043 with self.cached_session() as sess: 3044 self.evaluate(v.non_dep_variable.assign(42.)) 3045 save_path = saver.save(sess, prefix) 3046 self.evaluate(v.non_dep_variable.assign(43.)) 3047 saver.restore(sess, save_path) 3048 self.assertEqual(42., self.evaluate(v.non_dep_variable)) 3049 3050 @test_util.run_in_graph_and_eager_modes 3051 def testMoreComplexSaveableReturned(self): 3052 v = _OwnsMirroredVariables() 3053 test_dir = self.get_temp_dir() 3054 prefix = os.path.join(test_dir, "ckpt") 3055 self.evaluate(v.non_dep_variable.assign(42.)) 3056 for saver in (saver_module.Saver(var_list=[v]), 3057 saver_module.Saver(var_list={"v": v})): 3058 with self.cached_session() as sess: 3059 save_path = saver.save(sess, prefix) 3060 self.evaluate(v.non_dep_variable.assign(43.)) 3061 self.evaluate(v.mirrored.assign(44.)) 3062 saver.restore(sess, save_path) 3063 self.assertEqual(42., self.evaluate(v.non_dep_variable)) 3064 self.assertEqual(42., self.evaluate(v.mirrored)) 3065 3066 def testSingleTensorEvaluation(self): 3067 3068 class _CountingSaveable(saver_module.BaseSaverBuilder.SaveableObject): 3069 3070 def __init__(self, name): 3071 self.eval_count = 0 3072 def _tensor(): 3073 self.eval_count += 1 3074 return constant_op.constant([1.]) 3075 dummy_op = constant_op.constant([2.]) 3076 super(_CountingSaveable, self).__init__( 3077 dummy_op, 3078 [saver_module.BaseSaverBuilder.SaveSpec( 3079 _tensor, "", name, dtype=dummy_op.dtype, 3080 device=dummy_op.device)], 3081 name) 3082 3083 def restore(self, restored_tensors, restored_shapes): 3084 """Restore the same value into both variables.""" 3085 pass 3086 3087 with context.eager_mode(): 3088 v = _CountingSaveable("foo") 3089 saver = saver_module.Saver(var_list=[v]) 3090 test_dir = self.get_temp_dir() 3091 prefix = os.path.join(test_dir, "ckpt") 3092 with self.cached_session() as sess: 3093 save_path = saver.save(sess, prefix) 3094 self.assertEqual(1, v.eval_count) 3095 saver.restore(sess, save_path) 3096 self.assertEqual(1, v.eval_count) 3097 3098 def testVariableNotFoundErrorRaised(self): 3099 # Restore does some tricky exception handling to figure out if it should 3100 # load an object-based checkpoint. Tests that the exception handling isn't 3101 # too broad. 3102 checkpoint_directory = self.get_temp_dir() 3103 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 3104 3105 a = resource_variable_ops.ResourceVariable(1., name="a") 3106 b = resource_variable_ops.ResourceVariable(1., name="b") 3107 a_saver = saver_module.Saver([a]) 3108 b_saver = saver_module.Saver([b]) 3109 with self.cached_session() as sess: 3110 self.evaluate(a.initializer) 3111 save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix) 3112 with self.assertRaisesRegex(errors.NotFoundError, 3113 "Key b not found in checkpoint"): 3114 b_saver.restore(sess=sess, save_path=save_path) 3115 3116 with self.assertRaises(errors.NotFoundError) as cs: 3117 b_saver.restore(sess=sess, save_path=save_path) 3118 3119 # Make sure we don't have a confusing "During handling of the above 3120 # exception" block in Python 3. 3121 self.assertNotIn("NewCheckpointReader", cs.exception.message) 3122 3123 @test_util.run_v1_only("train.Saver is V1 only API.") 3124 def testGraphChangedForRestoreErrorRaised(self): 3125 checkpoint_directory = self.get_temp_dir() 3126 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 3127 3128 with ops_lib.Graph().as_default() as g: 3129 a = variables.VariableV1(1., name="a") 3130 a_saver = saver_module.Saver([a]) 3131 3132 with self.session(graph=g) as sess: 3133 self.evaluate(a.initializer) 3134 save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix) 3135 3136 with ops_lib.Graph().as_default() as g: 3137 a = variables.VariableV1([1.], name="a") 3138 a_saver = saver_module.Saver([a]) 3139 with self.session(graph=g) as sess: 3140 with self.assertRaisesRegex( 3141 errors.InvalidArgumentError, 3142 "a mismatch between the current graph and the graph"): 3143 a_saver.restore(sess=sess, save_path=save_path) 3144 3145 3146if __name__ == "__main__": 3147 test.main() 3148