# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools to work with checkpoints."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import abc
import time

import six

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "load_checkpoint", "load_variable", "list_variables",
    "checkpoints_iterator", "init_from_checkpoint"
]


@tf_export("train.load_checkpoint")
def load_checkpoint(ckpt_dir_or_file):
  """Returns `CheckpointReader` for checkpoint found in `ckpt_dir_or_file`.

  If `ckpt_dir_or_file` resolves to a directory with multiple checkpoints,
  reader for the latest checkpoint is returned.

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint
      file.

  Returns:
    `CheckpointReader` object.

  Raises:
    ValueError: If `ckpt_dir_or_file` resolves to a directory with no
      checkpoints.
  """
  filename = _get_checkpoint_filename(ckpt_dir_or_file)
  if filename is None:
    raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                     "given directory %s" % ckpt_dir_or_file)
  return py_checkpoint_reader.NewCheckpointReader(filename)


@tf_export("train.load_variable")
def load_variable(ckpt_dir_or_file, name):
  """Returns the tensor value of the given variable in the checkpoint.
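
  Example usage (a minimal sketch; the checkpoint directory below is a
  placeholder, and the variable key is looked up with `tf.train.list_variables`
  rather than hard-coded):

  ```python
  import tensorflow as tf

  ckpt_dir = '/tmp/training_checkpoints'  # Placeholder; point at a real checkpoint.
  # List the (key, shape) pairs stored in the checkpoint, then read the value
  # of the first key back as a numpy array.
  key, _ = tf.train.list_variables(ckpt_dir)[0]
  value = tf.train.load_variable(ckpt_dir, key)
  ```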

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    name: Name of the variable to return.

  Returns:
    A numpy `ndarray` with a copy of the value of this variable.
  """
  # TODO(b/29227106): Fix this in the right place and remove this.
  if name.endswith(":0"):
    name = name[:-2]
  reader = load_checkpoint(ckpt_dir_or_file)
  return reader.get_tensor(name)


@tf_export("train.list_variables")
def list_variables(ckpt_dir_or_file):
  """Lists the checkpoint keys and shapes of variables in a checkpoint.

  Checkpoint keys are paths in a checkpoint graph.

  Example usage:

  ```python
  import tensorflow as tf
  import os
  ckpt_directory = "/tmp/training_checkpoints/ckpt"
  ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(ckpt, ckpt_directory, max_to_keep=3)
  train_and_checkpoint(model, manager)
  tf.train.list_variables(manager.latest_checkpoint)
  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.

  Returns:
    List of tuples `(key, shape)`.
  """
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  names = sorted(variable_map.keys())
  result = []
  for name in names:
    result.append((name, variable_map[name]))
  return result


def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Waits until a new checkpoint file is found.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used or `None` if we're expecting
      a checkpoint for the first time.
    seconds_to_sleep: The number of seconds to sleep for before looking for a
      new checkpoint.
    timeout: The maximum number of seconds to wait. If left as `None`, then the
      process will wait indefinitely.

  Returns:
    A new checkpoint path, or None if the timeout was reached.
  """
  logging.info("Waiting for new checkpoint at %s", checkpoint_dir)
  stop_time = time.time() + timeout if timeout is not None else None
  while True:
    checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None or checkpoint_path == last_checkpoint:
      if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
        return None
      time.sleep(seconds_to_sleep)
    else:
      logging.info("Found new checkpoint at %s", checkpoint_path)
      return checkpoint_path


@tf_export("train.checkpoints_iterator")
def checkpoints_iterator(checkpoint_dir,
                         min_interval_secs=0,
                         timeout=None,
                         timeout_fn=None):
  """Continuously yield new checkpoint files as they appear.

  The iterator only checks for new checkpoints when control flow has been
  reverted to it. This means it can miss checkpoints if your code takes longer
  to run between iterations than `min_interval_secs` or the interval at which
  new checkpoints are written.

  The `timeout` argument is the maximum number of seconds to block waiting for
  a new checkpoint. It is used in combination with the `timeout_fn` as
  follows:

  * If the timeout expires and no `timeout_fn` was specified, the iterator
    stops yielding.
  * If a `timeout_fn` was specified, that function is called and if it returns
    a true boolean value the iterator stops yielding.
  * If the function returns a false boolean value then the iterator resumes the
    wait for new checkpoints. At this point the timeout logic applies again.

  This behavior gives callers control over what to do if checkpoints do not
  come fast enough or stop being generated. For example, if callers have a way
  to detect that the training has stopped and know that no new checkpoints
  will be generated, they can provide a `timeout_fn` that returns `True` when
  the training has stopped. If they know that the training is still going on,
  they return `False` instead.
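
  Example usage (a minimal sketch; the directory is a placeholder and
  `run_eval` stands in for whatever per-checkpoint work the caller wants to
  do):

  ```python
  import tensorflow as tf

  def run_eval(checkpoint_path):
    print("Evaluating", checkpoint_path)  # Placeholder for real evaluation.

  # Waits up to 10 minutes for each new checkpoint; since no timeout_fn is
  # given, the loop simply ends once that timeout expires.
  for ckpt_path in tf.train.checkpoints_iterator(
      "/tmp/training_checkpoints", min_interval_secs=60, timeout=600):
    run_eval(ckpt_path)
  ```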

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    min_interval_secs: The minimum number of seconds between yielding
      checkpoints.
    timeout: The maximum number of seconds to wait between checkpoints. If left
      as `None`, then the process will wait indefinitely.
    timeout_fn: Optional function to call after a timeout. If the function
      returns True, then it means that no new checkpoints will be generated and
      the iterator will exit. The function is called with no arguments.

  Yields:
    String paths to latest checkpoint files as they arrive.
  """
  checkpoint_path = None
  while True:
    new_checkpoint_path = wait_for_new_checkpoint(
        checkpoint_dir, checkpoint_path, timeout=timeout)
    if new_checkpoint_path is None:
      if not timeout_fn:
        # timed out
        logging.info("Timed-out waiting for a checkpoint.")
        return
      if timeout_fn():
        # The timeout_fn indicated that we are truly done.
        return
      else:
        # The timeout_fn indicated that more checkpoints may come.
        continue
    start = time.time()
    checkpoint_path = new_checkpoint_path
    yield checkpoint_path
    time_to_next_eval = start + min_interval_secs - time.time()
    if time_to_next_eval > 0:
      time.sleep(time_to_next_eval)


@tf_export(v1=["train.init_from_checkpoint"])
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Replaces `tf.Variable` initializers so they load from a checkpoint file.

  @compatibility(TF2)
  `tf.compat.v1.train.init_from_checkpoint` is not recommended for restoring
  variable values in TF2.

  To restore checkpoints in TF2, please use
  `tf.keras.Model.load_weights` or `tf.train.Checkpoint.restore`. These APIs
  use an [object-based method of checkpointing]
  (https://www.tensorflow.org/guide/checkpoint#loading_mechanics), while
  `tf.compat.v1.train.init_from_checkpoint` relies on a more-fragile
  variable-name based method of checkpointing. There is no object-based
  equivalent of `init_from_checkpoint` in TF2.

  Please re-write your checkpoints immediately using the object-based APIs;
  see the [migration guide]
  (https://www.tensorflow.org/guide/migrate#checkpoint_compatibility) for more
  details.

  You can load a name-based checkpoint written by `tf.compat.v1.train.Saver`
  using `tf.train.Checkpoint.restore` or `tf.keras.Model.load_weights`. However,
  you may have to change the names of the variables in your model to match the
  variable names in the name-based checkpoint, which can be viewed with
  `tf.train.list_variables(path)`.

  Another option is to create an `assignment_map` that maps the name of the
  variables in the name-based checkpoint to the variables in your model, e.g.:

  ```
  {
      'sequential/dense/bias': model.variables[0],
      'sequential/dense/kernel': model.variables[1]
  }
  ```

  and use `tf.compat.v1.train.init_from_checkpoint(path, assignment_map)` to
  restore the name-based checkpoint.

  After restoring, re-encode your checkpoint using `tf.train.Checkpoint.save`
  or `tf.keras.Model.save_weights`.

  @end_compatibility

  Values are not loaded immediately, but when the initializer is run
  (typically by running a `tf.compat.v1.global_variables_initializer` op).

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports the following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Assignment map can be a dict, or a list of pairs. The latter is
  necessary to initialize multiple variables in the current graph from
  the same variable in the checkpoint.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.compat.v1.variable_scope('new_scope_1'):
    var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
                                     initializer=tf.compat.v1.zeros_initializer())
  with tf.compat.v1.variable_scope('new_scope_2'):
    var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
                                     initializer=tf.compat.v1.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
                                     initializer=tf.compat.v1.zeros_initializer(),
                                     partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, or a list of key-value pairs, where keys are names
      of the variables in the checkpoint and values are current variables or
      names of current variables (in default graph).

  Raises:
    ValueError: If missing variables in current graph, or if missing
      checkpoints or tensors in checkpoints.

  """
  init_from_checkpoint_fn = lambda _: _init_from_checkpoint(
      ckpt_dir_or_file, assignment_map)
  if distribution_strategy_context.get_cross_replica_context():
    init_from_checkpoint_fn(None)
  else:
    distribution_strategy_context.get_replica_context().merge_call(
        init_from_checkpoint_fn)


def _init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """See `init_from_checkpoint` for documentation."""
  ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  if isinstance(assignment_map, abc.Mapping):
    assignment_map = six.iteritems(assignment_map)

  # We only want to sort by tensor names.
  sort_key = lambda pair: pair[0]

  for tensor_name_in_ckpt, current_var_or_name in sorted(
      assignment_map, key=sort_key):
    var = None
    # Check if this is Variable object or list of Variable objects (in case of
    # partitioned variables).
    if _is_variable(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(_is_variable(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
      # Check if this variable is in var_store.
      var = store_vars.get(current_var_or_name, None)
      # Also check if variable is partitioned as list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, store_vars)
    if var is not None:
      # If 1 to 1 mapping was provided, find variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
            tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
        ))
      if _is_variable(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        var_name = ",".join(v.name for v in var)
      _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
      logging.debug("Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
    else:
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If scope to scope mapping was provided, find all variables in the scope
      # and create variable to variable mapping.
      scope_variables = set()
      for var_name in store_vars:
        if not scopes or var_name.startswith(scopes + "/"):
          # Consume /part_ if partitioned variable.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in sorted(scope_variables):
        # Lookup name with specified prefix and suffix from current variable.
        # If tensor_name given is '/' (root), don't use it for full name.
        full_tensor_name = var_name[len(scopes):]
        if current_var_or_name != "/":
          full_tensor_name = full_tensor_name[1:]
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + full_tensor_name
        # Remove trailing '/', if any, in the full_tensor_name
        if full_tensor_name.endswith("/"):
          full_tensor_name = full_tensor_name[:-1]
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, ckpt_dir_or_file
              ))
        var = store_vars.get(var_name, None)
        if var is None:
          var = _collect_partitioned_variable(var_name, store_vars)
        _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
        logging.debug("Initialize variable %s from checkpoint %s with %s",
                      var_name, ckpt_dir_or_file, full_tensor_name)


def _get_checkpoint_filename(ckpt_dir_or_file):
  """Returns checkpoint filename given directory or specific checkpoint file."""
  if gfile.IsDirectory(ckpt_dir_or_file):
    return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file


def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets the variable initializer to an assign op that initializes the variable
  from the tensor's value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    names_to_saveables = saveable_object_util.op_list_to_dict([variable])
    saveable_objects = []
    for name, op in names_to_saveables.items():
      for s in saveable_object_util.saveable_objects_for_op(op, name):
        saveable_objects.append(s)

    assert len(saveable_objects) == 1  # Should be only one variable.
    init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

    # pylint:disable=protected-access
    variable._initializer_op = init_op
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op
    # pylint:enable=protected-access


def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
                                      tensor_name):
  """Overrides initialization op of given variable or list of variables.

  Calls `_set_checkpoint_initializer` for each variable in the given list of
  variables.

  Args:
    variable_or_list: `tf.Variable` object or a list of `tf.Variable` objects.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.

  Raises:
    ValueError: if the objects in `variable_or_list` are not all partitions of
      the same large variable.
  """
  if isinstance(variable_or_list, (list, tuple)):
    # A set of slices.
    slice_name = None
    for v in variable_or_list:
      slice_info = v._save_slice_info  # pylint:disable=protected-access
      if slice_name is None:
        slice_name = slice_info.full_name
      elif slice_name != slice_info.full_name:
        raise ValueError("Slices must all be from the same tensor: %s != %s" %
                         (slice_name, slice_info.full_name))
      _set_checkpoint_initializer(v, ckpt_file, tensor_name, slice_info.spec)
  else:
    _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")


def _is_variable(x):
  return (isinstance(x, variables.Variable) or
          resource_variable_ops.is_resource_variable(x))


def _collect_partitioned_variable(name, all_vars):
  """Returns list of `tf.Variable` that comprise the partitioned variable."""
  if name + "/part_0" in all_vars:
    var = []
    i = 0
    while name + "/part_%d" % i in all_vars:
      var.append(all_vars[name + "/part_%d" % i])
      i += 1
    return var
  return None