# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools to work with checkpoints."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import six

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "load_checkpoint", "load_variable", "list_variables",
    "checkpoints_iterator", "init_from_checkpoint"
]


@tf_export("train.load_checkpoint")
def load_checkpoint(ckpt_dir_or_file):
  """Returns `CheckpointReader` for checkpoint found in `ckpt_dir_or_file`.

  If `ckpt_dir_or_file` resolves to a directory with multiple checkpoints,
  a reader for the latest checkpoint is returned.

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint
      file.

  Returns:
    `CheckpointReader` object.

  Raises:
    ValueError: If `ckpt_dir_or_file` resolves to a directory with no
      checkpoints.
  """
  filename = _get_checkpoint_filename(ckpt_dir_or_file)
  if filename is None:
    raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                     "given directory %s" % ckpt_dir_or_file)
  return py_checkpoint_reader.NewCheckpointReader(filename)


@tf_export("train.load_variable")
def load_variable(ckpt_dir_or_file, name):
  """Returns the tensor value of the given variable in the checkpoint.

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    name: Name of the variable to return.

  Returns:
    A numpy `ndarray` with a copy of the value of this variable.
  """
  # TODO(b/29227106): Fix this in the right place and remove this.
  if name.endswith(":0"):
    name = name[:-2]
  reader = load_checkpoint(ckpt_dir_or_file)
  return reader.get_tensor(name)


@tf_export("train.list_variables")
def list_variables(ckpt_dir_or_file):
  """Lists the checkpoint keys and shapes of variables in a checkpoint.

  Checkpoint keys are paths in a checkpoint graph.
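  For object-based checkpoints saved with `tf.train.Checkpoint`, a key
  typically looks like `'model/dense/kernel/.ATTRIBUTES/VARIABLE_VALUE'`
  (the path shown here is only illustrative); name-based V1 checkpoints use
  the variable names themselves as keys.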

  Example usage:

  ```python
  import tensorflow as tf
  # `model`, `optimizer` and `train_and_checkpoint` are assumed to be defined
  # elsewhere.
  ckpt_directory = "/tmp/training_checkpoints/ckpt"
  ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(ckpt, ckpt_directory, max_to_keep=3)
  train_and_checkpoint(model, manager)
  tf.train.list_variables(manager.latest_checkpoint)
  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.

  Returns:
    List of tuples `(key, shape)`.
  """
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  names = sorted(variable_map.keys())
  result = []
  for name in names:
    result.append((name, variable_map[name]))
  return result


def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Waits until a new checkpoint file is found.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used or `None` if we're expecting
      a checkpoint for the first time.
    seconds_to_sleep: The number of seconds to sleep for before looking for a
      new checkpoint.
    timeout: The maximum number of seconds to wait. If left as `None`, then the
      process will wait indefinitely.

  Returns:
    A new checkpoint path, or None if the timeout was reached.
  """
  logging.info("Waiting for new checkpoint at %s", checkpoint_dir)
  stop_time = time.time() + timeout if timeout is not None else None
  while True:
    checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None or checkpoint_path == last_checkpoint:
      if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
        return None
      time.sleep(seconds_to_sleep)
    else:
      logging.info("Found new checkpoint at %s", checkpoint_path)
      return checkpoint_path


@tf_export("train.checkpoints_iterator")
def checkpoints_iterator(checkpoint_dir,
                         min_interval_secs=0,
                         timeout=None,
                         timeout_fn=None):
  """Continuously yield new checkpoint files as they appear.

  The iterator only checks for new checkpoints when control flow has been
  reverted to it. This means it can miss checkpoints if your code takes longer
  to run between iterations than `min_interval_secs` or the interval at which
  new checkpoints are written.

  The `timeout` argument is the maximum number of seconds to block waiting for
  a new checkpoint. It is used in combination with the `timeout_fn` as
  follows:

  * If the timeout expires and no `timeout_fn` was specified, the iterator
    stops yielding.
  * If a `timeout_fn` was specified, that function is called and if it returns
    a true boolean value the iterator stops yielding.
  * If the function returns a false boolean value then the iterator resumes the
    wait for new checkpoints. At this point the timeout logic applies again.

  This behavior gives control to callers on what to do if checkpoints do not
  come fast enough or stop being generated. For example, if callers have a way
  to detect that the training has stopped and know that no new checkpoints
  will be generated, they can provide a `timeout_fn` that returns `True` when
  the training has stopped. If they know that the training is still going on
  they return `False` instead.

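  For example, an evaluation job can consume checkpoints as they are written
  (a minimal sketch; `evaluate` and `training_finished` are hypothetical
  user-defined callables and the paths are illustrative):

  ```python
  for ckpt_path in tf.train.checkpoints_iterator(
      "/tmp/train_dir", min_interval_secs=60, timeout=600,
      timeout_fn=training_finished):
    evaluate(ckpt_path)
  ```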

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    min_interval_secs: The minimum number of seconds between yielding
      checkpoints.
    timeout: The maximum number of seconds to wait between checkpoints. If left
      as `None`, then the process will wait indefinitely.
    timeout_fn: Optional function to call after a timeout. If the function
      returns True, then it means that no new checkpoints will be generated and
      the iterator will exit. The function is called with no arguments.

  Yields:
    String paths to latest checkpoint files as they arrive.
  """
  checkpoint_path = None
  while True:
    new_checkpoint_path = wait_for_new_checkpoint(
        checkpoint_dir, checkpoint_path, timeout=timeout)
    if new_checkpoint_path is None:
      if not timeout_fn:
        # timed out
        logging.info("Timed-out waiting for a checkpoint.")
        return
      if timeout_fn():
        # The timeout_fn indicated that we are truly done.
        return
      else:
        # The timeout_fn indicated that more checkpoints may come.
        continue
    start = time.time()
    checkpoint_path = new_checkpoint_path
    yield checkpoint_path
    time_to_next_eval = start + min_interval_secs - time.time()
    if time_to_next_eval > 0:
      time.sleep(time_to_next_eval)


@tf_export(v1=["train.init_from_checkpoint"])
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Replaces `tf.Variable` initializers so they load from a checkpoint file.

  Values are not loaded immediately, but when the initializer is run
  (typically by running a `tf.compat.v1.global_variables_initializer` op).

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports the following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.compat.v1.variable_scope('new_scope_1'):
    var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
                                     initializer=tf.compat.v1.zeros_initializer())
  with tf.compat.v1.variable_scope('new_scope_2'):
    var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
                                     initializer=tf.compat.v1.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
                                     initializer=tf.compat.v1.zeros_initializer(),
                                     partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of the variables in the
      checkpoint and values are current variables or names of current variables
      (in default graph).

  Raises:
    ValueError: If missing variables in current graph, or if missing
      checkpoints or tensors in checkpoints.
  """
  init_from_checkpoint_fn = lambda _: _init_from_checkpoint(
      ckpt_dir_or_file, assignment_map)
  if distribution_strategy_context.get_cross_replica_context():
    init_from_checkpoint_fn(None)
  else:
    distribution_strategy_context.get_replica_context().merge_call(
        init_from_checkpoint_fn)


def _init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """See `init_from_checkpoint` for documentation."""
  ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  for tensor_name_in_ckpt, current_var_or_name in sorted(
      six.iteritems(assignment_map)):
    var = None
    # Check if this is Variable object or list of Variable objects (in case of
    # partitioned variables).
    if _is_variable(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(_is_variable(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
      # Check if this variable is in var_store.
      var = store_vars.get(current_var_or_name, None)
      # Also check if variable is partitioned as list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, store_vars)
    if var is not None:
      # If 1 to 1 mapping was provided, find variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
            tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
        ))
      if _is_variable(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        var_name = ",".join(v.name for v in var)
      _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
      logging.debug("Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
    else:
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If scope to scope mapping was provided, find all variables in the scope
      # and create variable to variable mapping.
      scope_variables = set()
      for var_name in store_vars:
        if not scopes or var_name.startswith(scopes + "/"):
          # Consume /part_ if partitioned variable.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in sorted(scope_variables):
        # Lookup name with specified prefix and suffix from current variable.
        # If tensor_name given is '/' (root), don't use it for full name.
        full_tensor_name = var_name[len(scopes):]
        if current_var_or_name != "/":
          full_tensor_name = full_tensor_name[1:]
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + full_tensor_name
        # Remove trailing '/', if any, in the full_tensor_name
        if full_tensor_name.endswith("/"):
          full_tensor_name = full_tensor_name[:-1]
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, ckpt_dir_or_file
              ))
        var = store_vars.get(var_name, None)
        if var is None:
          var = _collect_partitioned_variable(var_name, store_vars)
        _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
        logging.debug("Initialize variable %s from checkpoint %s with %s",
                      var_name, ckpt_dir_or_file, full_tensor_name)


def _get_checkpoint_filename(ckpt_dir_or_file):
  """Returns checkpoint filename given directory or specific checkpoint file."""
  if gfile.IsDirectory(ckpt_dir_or_file):
    return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file


def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    names_to_saveables = saveable_object_util.op_list_to_dict([variable])
    saveable_objects = []
    for name, op in names_to_saveables.items():
      for s in saveable_object_util.saveable_objects_for_op(op, name):
        saveable_objects.append(s)

    assert len(saveable_objects) == 1  # Should be only one variable.
    init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

    # pylint:disable=protected-access
    variable._initializer_op = init_op
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op
    # pylint:enable=protected-access


def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
                                      tensor_name):
  """Overrides initialization op of given variable or list of variables.

  Calls `_set_checkpoint_initializer` for each variable in the given list of
  variables.

  Args:
    variable_or_list: `tf.Variable` object or a list of `tf.Variable` objects.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.

  Raises:
    ValueError: if the objects in `variable_or_list` are not all partitions of
      the same large variable.
  """
  if isinstance(variable_or_list, (list, tuple)):
    # A set of slices.
    slice_name = None
    for v in variable_or_list:
      slice_info = v._save_slice_info  # pylint:disable=protected-access
      if slice_name is None:
        slice_name = slice_info.full_name
      elif slice_name != slice_info.full_name:
        raise ValueError("Slices must all be from the same tensor: %s != %s" %
                         (slice_name, slice_info.full_name))
      _set_checkpoint_initializer(v, ckpt_file, tensor_name, slice_info.spec)
  else:
    _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")


def _is_variable(x):
  return (isinstance(x, variables.Variable) or
          resource_variable_ops.is_resource_variable(x))


def _collect_partitioned_variable(name, all_vars):
  """Returns list of `tf.Variable` that comprise the partitioned variable."""
  if name + "/part_0" in all_vars:
    var = []
    i = 0
    while name + "/part_%d" % i in all_vars:
      var.append(all_vars[name + "/part_%d" % i])
      i += 1
    return var
  return None
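

# Example usage of the loading utilities above (an illustrative sketch only;
# the checkpoint path and variable name are hypothetical):
#
#   reader = load_checkpoint("/tmp/model.ckpt")
#   for key, shape in list_variables("/tmp/model.ckpt"):
#     print(key, shape)
#   value = load_variable("/tmp/model.ckpt", "dense/kernel")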