# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools to work with checkpoints."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "load_checkpoint", "load_variable", "list_variables", "init_from_checkpoint"
]


@tf_export("train.load_checkpoint")
def load_checkpoint(ckpt_dir_or_file):
  """Returns `CheckpointReader` for checkpoint found in `ckpt_dir_or_file`.

  If `ckpt_dir_or_file` resolves to a directory with multiple checkpoints,
  a reader for the latest checkpoint is returned.

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint
      file.

  Returns:
    `CheckpointReader` object.

  Raises:
    ValueError: If `ckpt_dir_or_file` resolves to a directory with no
      checkpoints.
  """
  filename = _get_checkpoint_filename(ckpt_dir_or_file)
  if filename is None:
    raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                     "given directory %s" % ckpt_dir_or_file)
  return pywrap_tensorflow.NewCheckpointReader(filename)


@tf_export("train.load_variable")
def load_variable(ckpt_dir_or_file, name):
  """Returns the tensor value of the given variable in the checkpoint.

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    name: Name of the variable to return.

  Returns:
    A numpy `ndarray` with a copy of the value of this variable.
  """
  # TODO(b/29227106): Fix this in the right place and remove this.
  if name.endswith(":0"):
    name = name[:-2]
  reader = load_checkpoint(ckpt_dir_or_file)
  return reader.get_tensor(name)


@tf_export("train.list_variables")
def list_variables(ckpt_dir_or_file):
  """Returns the list of all variables in the checkpoint.

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.

  Returns:
    List of tuples `(name, shape)`.
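
  Example (a minimal usage sketch; '/tmp/model.ckpt' below is a hypothetical
  checkpoint path):

  ```python
  # Print the name and shape of every variable stored in the checkpoint.
  for name, shape in tf.train.list_variables('/tmp/model.ckpt'):
    print(name, shape)
  ```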
  """
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  names = sorted(variable_map.keys())
  result = []
  for name in names:
    result.append((name, variable_map[name]))
  return result


@tf_export(v1=["train.init_from_checkpoint"])
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Replaces `tf.Variable` initializers so they load from a checkpoint file.

  Values are not loaded immediately, but when the initializer is run
  (typically by running a `tf.global_variables_initializer` op).

  Note: This overrides the default initialization ops of the specified
  variables and redefines their dtype.

  The assignment map supports the following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize the variable `scope_name/variable_name` from
    `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize the given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize a list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    the checkpoint's root (i.e. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  # -- name='old_scope_1/var1', shape=[20, 2]
  # -- name='old_scope_1/var2', shape=[50, 4]
  # -- name='old_scope_2/var3', shape=[100, 100]

  # Create the new model's variables.
  with tf.variable_scope('new_scope_1'):
    var1 = tf.get_variable('var1', shape=[20, 2],
                           initializer=tf.zeros_initializer())
  with tf.variable_scope('new_scope_2'):
    var2 = tf.get_variable('var2', shape=[50, 4],
                           initializer=tf.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.get_variable(name='var3', shape=[100, 100],
                           initializer=tf.zeros_initializer(),
                           partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from the checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using the variable's name.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of the variables in the
      checkpoint and values are current variables or names of current variables
      (in the default graph).

  Raises:
    ValueError: If variables are missing from the current graph, or if the
      checkpoint or tensors in the checkpoint are missing.
  """
  init_from_checkpoint_fn = lambda _: _init_from_checkpoint(
      ckpt_dir_or_file, assignment_map)
  if distribution_strategy_context.get_cross_replica_context():
    init_from_checkpoint_fn(None)
  else:
    distribution_strategy_context.get_replica_context().merge_call(
        init_from_checkpoint_fn)


def _init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """See `init_from_checkpoint` for documentation."""
  ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  for tensor_name_in_ckpt, current_var_or_name in sorted(
      six.iteritems(assignment_map)):
    var = None
    # Check if this is a Variable object or a list of Variable objects (in the
    # case of partitioned variables).
    if _is_variable(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(_is_variable(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
      # Check if this variable is in var_store.
      var = store_vars.get(current_var_or_name, None)
      # Also check if the variable is partitioned as a list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, store_vars)
    if var is not None:
      # If a 1-to-1 mapping was provided, find the variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
            tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
        ))
      if _is_variable(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        var_name = ",".join([v.name for v in var])
      _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
      logging.debug("Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
    else:
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If a scope-to-scope mapping was provided, find all variables in the
      # scope and create a variable-to-variable mapping.
      scope_variables = set()
      for var_name in store_vars:
        if not scopes or var_name.startswith(scopes + "/"):
          # Consume '/part_' if this is a partitioned variable.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in sorted(scope_variables):
        # Look up the name with the specified prefix and suffix from the
        # current variable. If the given tensor_name is '/' (root), don't use
        # it for the full name.
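        # For example (hypothetical names): with scopes='new_scope_1',
        # var_name='new_scope_1/var1' and tensor_name_in_ckpt='old_scope_1/',
        # the full tensor name assembled below is 'old_scope_1/var1'.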
        full_tensor_name = var_name[len(scopes):]
        if current_var_or_name != "/":
          full_tensor_name = full_tensor_name[1:]
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + full_tensor_name
        # Remove any trailing '/' from the full_tensor_name.
        if full_tensor_name.endswith("/"):
          full_tensor_name = full_tensor_name[:-1]
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, ckpt_dir_or_file
              ))
        var = store_vars.get(var_name, None)
        if var is None:
          var = _collect_partitioned_variable(var_name, store_vars)
        _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
        logging.debug("Initialize variable %s from checkpoint %s with %s",
                      var_name, ckpt_dir_or_file, full_tensor_name)


def _get_checkpoint_filename(ckpt_dir_or_file):
  """Returns checkpoint filename given directory or specific checkpoint file."""
  if gfile.IsDirectory(ckpt_dir_or_file):
    return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file


def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides the given variable's initialization op.

  Sets the variable initializer to an assign op that initializes the variable
  from the tensor's value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    names_to_saveables = saveable_object_util.op_list_to_dict([variable])
    saveable_objects = []
    for name, op in names_to_saveables.items():
      for s in saveable_object_util.saveable_objects_for_op(op, name):
        saveable_objects.append(s)

    assert len(saveable_objects) == 1  # Should be only one variable.
    init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

    # pylint:disable=protected-access
    variable._initializer_op = init_op
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op
    # pylint:enable=protected-access


def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
                                      tensor_name):
  """Overrides the initialization op of a variable or a list of variables.

  Calls `_set_checkpoint_initializer` for each variable in the given list of
  variables.

  Args:
    variable_or_list: `tf.Variable` object or a list of `tf.Variable` objects.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.

  Raises:
    ValueError: If the objects in `variable_or_list` are not all partitions of
      the same large variable.
  """
  if isinstance(variable_or_list, (list, tuple)):
    # A set of slices.
    slice_name = None
    for v in variable_or_list:
      slice_info = v._save_slice_info  # pylint:disable=protected-access
      if slice_name is None:
        slice_name = slice_info.full_name
      elif slice_name != slice_info.full_name:
        raise ValueError("Slices must all be from the same tensor: %s != %s" %
                         (slice_name, slice_info.full_name))
      _set_checkpoint_initializer(v, ckpt_file, tensor_name, slice_info.spec)
  else:
    _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")


def _is_variable(x):
  return (isinstance(x, variables.Variable) or
          resource_variable_ops.is_resource_variable(x))


def _collect_partitioned_variable(name, all_vars):
  """Returns list of `tf.Variable` that comprise the partitioned variable."""
  if name + "/part_0" in all_vars:
    var = []
    i = 0
    while name + "/part_%d" % i in all_vars:
      var.append(all_vars[name + "/part_%d" % i])
      i += 1
    return var
  return None