# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for working with and creating SaveableObjects."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import six

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity


# Op names which identify variable reads which should be saved.
_VARIABLE_OPS = {
    "Variable",
    "VariableV2",
    "AutoReloadVariable",
    "VarHandleOp",
    "ReadVariableOp",
}

# Op types that are wrapped in a `ReferenceVariableSaveable`; any other
# variable op is wrapped in a `ResourceVariableSaveable`.
_REFERENCE_VARIABLE_OPS = ("Variable", "VariableV2", "AutoReloadVariable")


def set_cpu0(device_string):
  """Creates a new device string based on `device_string` but using /CPU:0.

  If the device is already on /CPU:0, this is a no-op.

  Args:
    device_string: A device string.

  Returns:
    A device string.
  """
  parsed_device = pydev.DeviceSpec.from_string(device_string)
  parsed_device = parsed_device.replace(device_type="CPU", device_index=0)
  return parsed_device.to_string()


class ReferenceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles reference variables."""

  def __init__(self, var, slice_spec, name):
    """Creates a saveable with a single SaveSpec for `var`.

    Args:
      var: The variable (or variable tensor) to save; its `dtype` is recorded
        in the SaveSpec.
      slice_spec: The slice specification passed through to the SaveSpec.
      name: The name to save the variable under.
    """
    spec = saveable_object.SaveSpec(var, slice_spec, name, dtype=var.dtype)
    super(ReferenceVariableSaveable, self).__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the first restored tensor to the wrapped variable.

    Args:
      restored_tensors: Sequence of restored tensors; only the first is used.
      restored_shapes: Either `None`, or a sequence whose first element is the
        shape the restored tensor is reshaped to before assignment.

    Returns:
      The result of `state_ops.assign`.
    """
    restored_tensor = restored_tensors[0]
    if restored_shapes is not None:
      restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
    # Shape validation is only requested when no reshape was asked for and the
    # variable's static shape is fully known.
    return state_ops.assign(
        self.op,
        restored_tensor,
        validate_shape=restored_shapes is None and
        self.op.get_shape().is_fully_defined())


class ResourceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles ResourceVariables."""

  def __init__(self, var, slice_spec, name):
    """Creates a saveable with a single SaveSpec for `var`.

    Args:
      var: Either a resource variable or an `ops.Tensor` whose producing op
        takes the variable handle as its first input.
      slice_spec: The slice specification passed through to the SaveSpec.
      name: The name to save the variable under.

    Raises:
      ValueError: If `var` is neither a Tensor nor a resource variable.
    """
    self._var_device = var.device
    self._var_shape = var.shape
    if isinstance(var, ops.Tensor):
      self.handle_op = var.op.inputs[0]
      tensor = var
    elif resource_variable_ops.is_resource_variable(var):

      def _read_variable_closure(v):
        def f():
          with ops.device(v.device):
            if context.executing_eagerly() and not v.is_initialized():
              # A SaveSpec tensor value of `None` indicates that the variable
              # is uninitialized.
              return None
            x = v.read_value()
            # To allow variables placed on non-CPU devices to be checkpointed,
            # we copy them to CPU on the same machine first.
            with ops.device("/device:CPU:0"):
              return array_ops.identity(x)

        return f

      self.handle_op = var.handle
      tensor = _read_variable_closure(var)
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          " Got: %s" % repr(var))
    spec = saveable_object.SaveSpec(tensor, slice_spec, name,
                                    dtype=var.dtype, device=var.device)
    super(ResourceVariableSaveable, self).__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the first restored tensor to the variable handle.

    Args:
      restored_tensors: Sequence of restored tensors; only the first is used.
      restored_shapes: Either `None`, or a sequence whose first element is the
        shape the restored tensor is reshaped to before assignment.

    Returns:
      The result of
      `resource_variable_ops.shape_safe_assign_variable_handle`.
    """
    restored_tensor = restored_tensors[0]
    if restored_shapes is not None:
      restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
    # Copy the restored tensor to the variable's device.
    with ops.device(self._var_device):
      restored_tensor = array_ops.identity(restored_tensor)
      return resource_variable_ops.shape_safe_assign_variable_handle(
          self.handle_op, self._var_shape, restored_tensor)


def _tensor_comes_from_variable(v):
  """Returns True if `v` is a Tensor produced by one of `_VARIABLE_OPS`."""
  return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS


def _variable_saveable(variable, slice_spec, name):
  """Wraps `variable` in the SaveableObject class matching its op type."""
  if variable.op.type in _REFERENCE_VARIABLE_OPS:
    return ReferenceVariableSaveable(variable, slice_spec, name)
  return ResourceVariableSaveable(variable, slice_spec, name)


def saveable_objects_for_op(op, name):
  """Create `SaveableObject`s from an operation.

  Args:
    op: A variable, operation, or SaveableObject to coerce into a
      SaveableObject.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string.
    ValueError: For operations with no known conversion to SaveableObject.
  """
  if not isinstance(name, six.string_types):
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        "trackable operations. Name is not a string: %s" % name)
  if isinstance(op, saveable_object.SaveableObject):
    yield op
  elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
    if isinstance(op, variables.PartitionedVariable):
      op = list(op)
    # A set of slices.
    slice_name = None
    # pylint: disable=protected-access
    for variable in op:
      if isinstance(variable, saveable_object.SaveableObject):
        yield variable
        continue
      if not isinstance(variable, variables.Variable):
        raise ValueError("Slices must all be Variables: %s" % variable)
      if not variable._save_slice_info:
        raise ValueError("Slices must all be slices: %s" % variable)
      if slice_name is None:
        slice_name = variable._save_slice_info.full_name
      elif slice_name != variable._save_slice_info.full_name:
        raise ValueError(
            "Slices must all be from the same tensor: %s != %s" %
            (slice_name, variable._save_slice_info.full_name))
      yield _variable_saveable(
          variable, variable._save_slice_info.spec, name)
    # pylint: enable=protected-access
  elif isinstance(op, trackable.Trackable) and not isinstance(
      op, variables.Variable):
    # pylint: disable=protected-access
    for attr, factory in op._gather_saveables_for_checkpoint().items():
      if attr == trackable.VARIABLE_VALUE_KEY:
        # Keep original name for classes masquerading as variables.
        full_name = name
      else:
        full_name = name + "_" + attr
      # Distinct local names are used here so that `op`, which is being
      # iterated over, is not rebound inside its own loop.
      saveable = factory(full_name) if callable(factory) else factory
      for nested_saveable in saveable_objects_for_op(saveable, saveable.name):
        yield nested_saveable
    # pylint: enable=protected-access
  else:
    # A variable or tensor.
    if isinstance(op, resource_variable_ops.BaseResourceVariable):
      if op._in_graph_mode:  # pylint: disable=protected-access
        variable = op._graph_element  # pylint: disable=protected-access
      else:
        variable = op
      yield ResourceVariableSaveable(variable, "", name)
    else:
      if context.executing_eagerly():
        raise ValueError("Can only save/restore ResourceVariables when "
                         "executing eagerly, got type: %s." % type(op))

      variable = ops.convert_to_tensor(op, as_ref=True)
      if not _tensor_comes_from_variable(variable):
        raise TypeError("names_to_saveables must be a dict mapping string "
                        "names to Tensors/Variables. Not a variable: %s" %
                        variable)
      yield _variable_saveable(variable, "", name)


def op_list_to_dict(op_list, convert_variable_to_tensor=True):
  """Create a dictionary of names to operation lists.

  Args:
    op_list: A (nested) list, tuple, or set of Variables or SaveableObjects.
    convert_variable_to_tensor: Whether or not to convert single Variables
      with no slice info into Tensors.

  Returns:
    A dictionary of names to the operations that must be saved under
    that name.  Variables with save_slice_info are grouped together under the
    same key in no particular order.

  Raises:
    TypeError: If the type of op_list or its elements is not supported.
    ValueError: If at least two saveables share the same name.
  """
  if not isinstance(op_list, (list, tuple, set)):
    raise TypeError("Variables to save should be passed in a dict or a "
                    "list: %s" % op_list)
  # List casting is necessary to support sets.
  op_list = nest.flatten(list(op_list))
  # When ResourceVariables are converted to Tensors, read ops are added to the
  # graph. Sorting the op_list ensures that the resulting graph is always
  # constructed in a deterministic way:
  op_list = sorted(op_list, key=lambda x: x.name)
  names_to_saveables = {}
  # pylint: disable=protected-access
  for var in op_list:
    resource_or_ref_variable = isinstance(
        var,
        (resource_variable_ops.BaseResourceVariable, variables.RefVariable))

    if isinstance(var, saveable_object.SaveableObject):
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.PartitionedVariable):
      if var.name in names_to_saveables:
        raise ValueError("At least two variables have the same name: %s" %
                         var.name)
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.Variable) and var._save_slice_info:
      name = var._save_slice_info.full_name
      if name in names_to_saveables:
        if not isinstance(names_to_saveables[name], list):
          raise ValueError("Mixing slices and non-slices with the same name: "
                           "%s" % name)
        names_to_saveables[name].append(var)
      else:
        names_to_saveables[name] = [var]
    elif isinstance(var, trackable.Trackable) and not resource_or_ref_variable:
      trackable_saveables = [
          (factory() if callable(factory) else factory)
          for factory in var._gather_saveables_for_checkpoint().values()]
      names_to_saveables.update(
          op_list_to_dict(trackable_saveables))
    else:
      # Variables (reference and resource) have an _in_graph_mode property
      # indicating whether they were created in a graph building context. We
      # also get Tensors when graph building, which do not have this property.
      if not getattr(var, "_in_graph_mode", True):
        if not isinstance(var, resource_variable_ops.BaseResourceVariable):
          raise ValueError(
              "Can only save/restore ResourceVariables when eager execution "
              "is enabled, type: %s." % type(var))
        set_var = names_to_saveables.setdefault(var._shared_name, var)
        if set_var is not var:
          raise ValueError(
              ("Two different ResourceVariable objects with the same "
               "shared_name '%s' were passed to the Saver. This likely means "
               "that they were created in different Graphs or isolation "
               "contexts, and may not be checkpointed together.") %
              (var._shared_name,))
      else:
        if convert_variable_to_tensor:
          if isinstance(var, resource_variable_ops.BaseResourceVariable):
            var = var._graph_element  # pylint: disable=protected-access
          else:
            var = ops.convert_to_tensor(var, as_ref=True)
          if not _tensor_comes_from_variable(var):
            raise TypeError("Variable to save is not a Variable: %s" % var)
        if var.op.type == "ReadVariableOp":
          # Name the entry after the op producing the read op's first input
          # rather than after the read op itself.
          name = var.op.inputs[0].op.name
        else:
          name = var.op.name
        if name in names_to_saveables:
          raise ValueError("At least two variables have the same name: %s" %
                           name)
        names_to_saveables[name] = var

  # pylint: enable=protected-access
  return names_to_saveables


def _add_saveable(saveables, seen_ops, saveable):
  """Adds the saveable to the saveables list.

  Args:
    saveables: List to append the SaveableObject to.
    seen_ops: Set of the ops of the saveables already processed.  Used to
      check that each saveable is only saved once.
    saveable: The saveable.

  Raises:
    ValueError: If the saveable has already been processed.
  """
  # Saveables whose op is `None` are never treated as duplicates.
  if saveable.op is not None and saveable.op in seen_ops:
    raise ValueError("The same saveable will be restored with two names: %s" %
                     saveable.name)
  saveables.append(saveable)
  seen_ops.add(saveable.op)


def validate_and_slice_inputs(names_to_saveables):
  """Returns the variables and names that will be used for a Saver.

  Args:
    names_to_saveables: A dict (k, v) where k is the name of an operation and
       v is an operation to save or a BaseSaverBuilder.Saver. Any non-dict
       value is first converted with `op_list_to_dict`.

  Returns:
    A list of SaveableObjects.

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a trackable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  if not isinstance(names_to_saveables, dict):
    names_to_saveables = op_list_to_dict(names_to_saveables)

  saveables = []
  seen_ops = object_identity.ObjectIdentitySet()
  for name, op in sorted(names_to_saveables.items(),
                         # Avoid comparing ops, sort only by name.
                         key=lambda x: x[0]):
    for converted_saveable_object in saveable_objects_for_op(op, name):
      _add_saveable(saveables, seen_ops, converted_saveable_object)
  return saveables


def trace_save_restore_functions(object_to_save):
  """Gathers all SaveableObjects and traces the save and restore ops.

  Args:
    object_to_save: An object providing `_gather_saveables_for_checkpoint()`,
      which returns a dict mapping names to saveable factories.

  Returns:
    A dict mapping each name to a `(save function, restore function)` tuple.
    Entries are omitted for non-callable factories and for factories whose
    tracing yields no save function.
  """
  saveable_map = {}  # Maps name -> (save function, restore function)
  for name, saveable_factory in (
      object_to_save._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
    if not callable(saveable_factory):
      if isinstance(saveable_factory, saveable_object.SaveableObject):
        logging.debug(
            "Trackable {} should return callable factories, not SaveableObjects"
            " in `_gather_saveables_for_checkpoint`. This could lead to "
            "problems loading the SavedModel back into Python."
            .format(object_to_save))
      continue

    if is_factory_for_restored_saveable_object(saveable_factory):
      # The factory already carries its save/restore functions; reuse them
      # instead of tracing again.
      saveable_map[name] = (saveable_factory.keywords["save_function"],
                            saveable_factory.keywords["restore_function"])
    else:
      concrete_save_fn, concrete_restore_fn = _trace_save_and_restore_function(
          saveable_factory, object_to_save)
      if concrete_save_fn is not None:
        saveable_map[name] = (concrete_save_fn, concrete_restore_fn)
  return saveable_map


def _trace_save_and_restore_function(saveable_factory, object_to_save):
  """Traces the save and restore concrete functions.

  Args:
    saveable_factory: Callable accepting a `name` keyword argument and
      returning a SaveableObject or a sequence of SaveableObjects.
    object_to_save: The object the factory belongs to; only used in the
      warning message.

  Returns:
    A `(concrete_save_fn, concrete_restore_fn)` tuple, or `(None, None)` when
    any created saveable is a `PythonStateSaveable` or a `NoRestoreSaveable`.
  """
  # Filled in as a side effect of tracing `save_fn`, so that the saveables the
  # factory created can be inspected and reused by `restore_fn` below.
  saveables = []

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def save_fn(checkpoint_key):
    maybe_saveable = saveable_factory(name=checkpoint_key)
    if isinstance(maybe_saveable, saveable_object.SaveableObject):
      maybe_saveable = [maybe_saveable]
    saveables[:] = maybe_saveable

    # Return list of all SaveSpecs created by the factory.
    ret = []
    for saveable in saveables:
      for spec in saveable.specs:
        ret.append({"name": spec.name, "tensor": spec.tensor,
                    "slice_spec": spec.slice_spec})
    return ret

  concrete_save_fn = save_fn.get_concrete_function()
  if any(isinstance(saveable, trackable.PythonStateSaveable)
         for saveable in saveables):
    logging.warn(
        "Note that object {} stores python values into the checkpoint. "
        "These values will not be restored when loading the SavedModel "
        "into python.".format(object_to_save))
    return None, None
  if any(isinstance(saveable, trackable.NoRestoreSaveable)
         for saveable in saveables):
    return None, None

  # `restored_type_specs` is the flat input signature of `restore_fn`;
  # `tensor_structure` holds one list of spec names per saveable and is used
  # to regroup the flat arguments per saveable.
  restored_type_specs = []
  tensor_structure = []
  for saveable in saveables:
    saveable_tensor_structure = []
    tensor_structure.append(saveable_tensor_structure)
    for spec in saveable.specs:
      restored_type_specs.append(type_spec.type_spec_from_value(spec.tensor))
      saveable_tensor_structure.append(spec.name)

  @def_function.function(input_signature=restored_type_specs)
  def restore_fn(*restored_tensors):
    structured_restored_tensors = nest.pack_sequence_as(
        tensor_structure, restored_tensors)
    for saveable, saveable_tensors in zip(saveables,
                                          structured_restored_tensors):
      saveable.restore(saveable_tensors, restored_shapes=None)
    return 1

  concrete_restore_fn = restore_fn.get_concrete_function()
  return concrete_save_fn, concrete_restore_fn


class RestoredSaveableObject(saveable_object.SaveableObject):
  """SaveableObject restored from SavedModel using the traced save/restore."""

  def __init__(self, save_function, restore_function, name):
    """Builds the SaveSpecs by calling `save_function` with `name`.

    Args:
      save_function: Callable taking the name tensor and returning an iterable
        of dicts with "tensor", "slice_spec" and "name" entries.
      restore_function: Callable invoked by `restore` with the restored
        tensors as positional arguments.
      name: The checkpoint name, either a TF type or a value that is converted
        to a constant.
    """
    self.save_function = save_function
    self.restore_function = restore_function

    if tensor_util.is_tf_type(name):
      name_tensor = name
    else:
      with ops.init_scope():
        name_tensor = constant_op.constant(name)
    tensors = save_function(name_tensor)
    specs = [saveable_object.SaveSpec(x["tensor"], x["slice_spec"], x["name"])
             for x in tensors]
    super(RestoredSaveableObject, self).__init__(None, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    """Calls the restore function with one restored tensor per SaveSpec.

    Args:
      restored_tensors: Indexable collection of restored tensors; the first
        `len(self.specs)` entries are forwarded.
      restored_shapes: Unused.

    Returns:
      Whatever `restore_function` returns.
    """
    del restored_shapes  # unused
    return self.restore_function(
        *[restored_tensors[i] for i in range(len(self.specs))])


def restored_saved_object_factory(save_function, restore_function):
  """Returns a `functools.partial` factory for `RestoredSaveableObject`.

  Args:
    save_function: Bound as the `save_function` keyword of the factory.
    restore_function: Bound as the `restore_function` keyword of the factory.

  Returns:
    A partial of `RestoredSaveableObject` that still needs a `name`.
  """
  return functools.partial(RestoredSaveableObject,
                           save_function=save_function,
                           restore_function=restore_function)


def create_saveable_object(factory, name, call_with_mapped_captures):
  """Creates a SaveableObject while potentially in a different graph.

  When creating the frozen saver for SavedModel, the save and restore ops are
  placed in a separate graph. Since RestoredSaveableObject uses tf.functions to
  save and restore, the function captures must be mapped to the new graph.

  Args:
    factory: Factory method for creating the SaveableObject.
    name: Checkpoint key of this SaveableObject.
    call_with_mapped_captures: Helper that calls a tf.function while remapping
      the captures.

  Returns:
    a SaveableObject.
  """
  if (call_with_mapped_captures is None or
      not is_factory_for_restored_saveable_object(factory)):
    return factory(name=name)

  concrete_save_fn = factory.keywords["save_function"]
  def save_fn(checkpoint_key):
    return call_with_mapped_captures(concrete_save_fn, [checkpoint_key])

  concrete_restore_fn = factory.keywords["restore_function"]
  def restore_fn(*restored_tensors):
    return call_with_mapped_captures(concrete_restore_fn, restored_tensors)

  # Keyword arguments given here override the ones bound in the partial.
  return factory(save_function=save_fn, restore_function=restore_fn, name=name)


def is_factory_for_restored_saveable_object(factory):
  """Returns True if `factory` is a partial of `RestoredSaveableObject`."""
  return (isinstance(factory, functools.partial) and
          factory.func is RestoredSaveableObject)