# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for working with and creating SaveableObjects."""
import functools

from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import base as trackable
from tensorflow.python.trackable import python_state
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.types import core
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export

# Op names which identify variable reads which should be saved.
_VARIABLE_OPS = {
    "Variable",
    "VariableV2",
    "AutoReloadVariable",
    "VarHandleOp",
    "ReadVariableOp",
}

# Subset of `_VARIABLE_OPS` that is wrapped in a `ReferenceVariableSaveable`;
# every other variable op is wrapped in a `ResourceVariableSaveable`.
_REFERENCE_VARIABLE_OPS = ("Variable", "VariableV2", "AutoReloadVariable")


def set_cpu0(device_string):
  """Creates a new device string based on `device_string` but using /CPU:0.

  If the device is already on /CPU:0, this is a no-op.

  Args:
    device_string: A device string.

  Returns:
    A device string.
  """
  parsed_device = pydev.DeviceSpec.from_string(device_string)
  parsed_device = parsed_device.replace(device_type="CPU", device_index=0)
  return parsed_device.to_string()


class ReferenceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles reference variables."""

  def __init__(self, var, slice_spec, name):
    """Wraps `var` in a single `SaveSpec`.

    Args:
      var: The object to save; its `dtype` is recorded in the spec.
      slice_spec: Slice specification string passed through to the `SaveSpec`.
      name: Name used both for the `SaveSpec` and for this saveable.
    """
    spec = saveable_object.SaveSpec(var, slice_spec, name, dtype=var.dtype)
    super().__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the first restored tensor to the wrapped op.

    Args:
      restored_tensors: Sequence whose first element is the value to assign.
      restored_shapes: Either `None`, or a sequence whose first element is the
        shape the restored tensor is reshaped to before assignment.

    Returns:
      The result of `state_ops.assign`.
    """
    restored_tensor = restored_tensors[0]
    if restored_shapes is not None:
      restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
    # Shape validation is only requested when no reshape was applied and the
    # static shape of the target is fully known.
    return state_ops.assign(
        self.op,
        restored_tensor,
        validate_shape=restored_shapes is None and
        self.op.get_shape().is_fully_defined())


class ResourceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles ResourceVariables."""

  def __init__(self, var, slice_spec, name):
    """Builds the `SaveSpec` for a resource variable or a read of one.

    Args:
      var: Either an `ops.Tensor` (its op's first input is used as the handle)
        or a resource variable.
      slice_spec: Slice specification string passed through to the `SaveSpec`.
      name: Name used both for the `SaveSpec` and for this saveable.

    Raises:
      ValueError: If `var` is neither a tensor nor a resource variable.
    """
    self._var_device = var.device
    self._var_shape = var.shape
    if isinstance(var, ops.Tensor):
      self.handle_op = var.op.inputs[0]
      tensor = var
    elif resource_variable_ops.is_resource_variable(var):

      def _read_variable_closure(v):
        def f():
          with ops.device(v.device):
            if context.executing_eagerly() and not v.is_initialized():
              # A SaveSpec tensor value of `None` indicates that the variable is
              # uninitialized.
              return None
            # Read the variable without making a copy to limit memory usage.
            x = v.read_value_no_copy()
            # To allow variables placed on non-CPU devices to be checkpointed,
            # we copy them to CPU on the same machine first.
            with ops.device("/device:CPU:0"):
              return array_ops.identity(x)

        return f

      self.handle_op = var.handle
      tensor = _read_variable_closure(var)
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          f" Got: {repr(var)}")
    spec = saveable_object.SaveSpec(tensor, slice_spec, name,
                                    dtype=var.dtype, device=var.device)
    super().__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Restores tensors. Raises ValueError if incompatible shape found."""
    restored_tensor = restored_tensors[0]
    if restored_shapes is not None:
      restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
    # Copy the restored tensor to the variable's device.
    with ops.device(self._var_device):
      restored_tensor = array_ops.identity(restored_tensor)
      try:
        assigned_variable = resource_variable_ops.shape_safe_assign_variable_handle(
            self.handle_op, self._var_shape, restored_tensor)
      except ValueError as e:
        raise ValueError(
            f"Received incompatible tensor with shape {restored_tensor.shape} "
            f"when attempting to restore variable with shape {self._var_shape} "
            f"and name {self.name}.") from e
      return assigned_variable


def _tensor_comes_from_variable(v):
  """Returns True if `v` is a Tensor produced by one of `_VARIABLE_OPS`."""
  return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS


def saveable_objects_for_op(op, name):
  """Create `SaveableObject`s from an operation.

  Args:
    op: A variable, operation, or SaveableObject to coerce into a
      SaveableObject.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string.
    ValueError: For operations with no known conversion to SaveableObject.
  """
  if not isinstance(name, str):
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        f"trackable operations. Name is not a string: {name}")
  if isinstance(op, saveable_object.SaveableObject):
    yield op
  elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
    if isinstance(op, variables.PartitionedVariable):
      op = list(op)
    # A set of slices.
    slice_name = None
    # pylint: disable=protected-access
    for variable in op:
      if isinstance(variable, saveable_object.SaveableObject):
        yield variable
        continue
      if not isinstance(variable, variables.Variable):
        raise ValueError(f"Slices must all be Variables: {variable}")
      if not variable._save_slice_info:
        raise ValueError(f"Slices must all be slices: {variable}")
      if slice_name is None:
        slice_name = variable._save_slice_info.full_name
      elif slice_name != variable._save_slice_info.full_name:
        raise ValueError(
            f"Slices must all be from the same tensor: {slice_name} != "
            f"{variable._save_slice_info.full_name}")
      if variable.op.type in _REFERENCE_VARIABLE_OPS:
        yield ReferenceVariableSaveable(
            variable, variable._save_slice_info.spec, name)
      else:
        yield ResourceVariableSaveable(variable, variable._save_slice_info.spec,
                                       name)
    # pylint: enable=protected-access
  elif isinstance(op, trackable.Trackable) and not isinstance(
      op, variables.Variable):
    # pylint: disable=protected-access
    for attr, factory in saveable_objects_from_trackable(op).items():
      if attr == trackable.VARIABLE_VALUE_KEY:
        # Keep original name for classes masquerading as variables.
        full_name = name
      else:
        full_name = name + "_" + attr
      # Use a dedicated local instead of re-binding the `op` parameter, so the
      # object being iterated is never confused with what the factory produced.
      produced = factory(full_name) if callable(factory) else factory
      yield from saveable_objects_for_op(produced, produced.name)
    # pylint: enable=protected-access
  else:
    # A variable or tensor.
    if isinstance(op, resource_variable_ops.BaseResourceVariable):
      if op._in_graph_mode:  # pylint: disable=protected-access
        variable = op._graph_element  # pylint: disable=protected-access
      else:
        variable = op
      yield ResourceVariableSaveable(variable, "", name)
    else:
      if context.executing_eagerly():
        raise ValueError("Can only save/restore ResourceVariables when "
                         f"executing eagerly, got type: {type(op)}.")

      variable = ops.convert_to_tensor(op, as_ref=True)
      if not _tensor_comes_from_variable(variable):
        raise TypeError(
            "names_to_saveables must be a dict mapping string "
            f"names to Tensors/Variables. Not a variable: {variable}")
      if variable.op.type in _REFERENCE_VARIABLE_OPS:
        yield ReferenceVariableSaveable(variable, "", name)
      else:
        yield ResourceVariableSaveable(variable, "", name)


def op_list_to_dict(op_list, convert_variable_to_tensor=True):
  """Create a dictionary of names to operation lists.

  Args:
    op_list: A (nested) list, tuple, or set of Variables or SaveableObjects.
    convert_variable_to_tensor: Whether or not to convert single Variables
      with no slice info into Tensors.

  Returns:
    A dictionary of names to the operations that must be saved under
    that name.  Variables with save_slice_info are grouped together under the
    same key in no particular order.

  Raises:
    TypeError: If the type of op_list or its elements is not supported.
    ValueError: If at least two saveables share the same name.
  """
  if not isinstance(op_list, (list, tuple, set)):
    raise TypeError("Variables to save should be passed in a dict or a "
                    f"list. Got {op_list}")
  # List casting is necessary to support sets.
  op_list = nest.flatten(list(op_list))
  # When ResourceVariables are converted to Tensors, read ops are added to the
  # graph. Sorting the op_list ensures that the resulting graph is always
  # constructed in a deterministic way:
  op_list = sorted(op_list, key=lambda x: x.name)
  names_to_saveables = {}
  # pylint: disable=protected-access
  for var in op_list:
    resource_or_ref_variable = isinstance(
        var,
        (resource_variable_ops.BaseResourceVariable, variables.RefVariable))

    if isinstance(var, saveable_object.SaveableObject):
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.PartitionedVariable):
      if var.name in names_to_saveables:
        raise ValueError(
            f"At least two variables have the same name: {var.name}")
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.Variable) and var._save_slice_info:
      name = var._save_slice_info.full_name
      if name in names_to_saveables:
        if not isinstance(names_to_saveables[name], list):
          raise ValueError("Mixing slices and non-slices with the same name: "
                           f"{name}")
        names_to_saveables[name].append(var)
      else:
        names_to_saveables[name] = [var]
    elif isinstance(var, trackable.Trackable) and not resource_or_ref_variable:
      trackable_saveables = [
          (factory() if callable(factory) else factory)
          for factory in saveable_objects_from_trackable(var).values()]
      names_to_saveables.update(
          op_list_to_dict(trackable_saveables))
    else:
      # Variables (reference and resource) have an _in_graph_mode property
      # indicating whether they were created in a graph building context. We
      # also get Tensors when graph building, which do not have this property.
      if not getattr(var, "_in_graph_mode", True):
        if not isinstance(var, resource_variable_ops.BaseResourceVariable):
          raise ValueError(
              "Can only save/restore ResourceVariables when eager execution "
              f"is enabled. Got type: {type(var)}.")
        set_var = names_to_saveables.setdefault(var._shared_name, var)
        if set_var is not var:
          raise ValueError(
              "Two different ResourceVariable objects with the same "
              f"shared_name '{var._shared_name}' were passed to the Saver. This"
              " likely means that they were created in different Graphs or "
              "isolated contexts, and may not be checkpointed together.")
      else:
        if convert_variable_to_tensor:
          if isinstance(var, resource_variable_ops.BaseResourceVariable):
            var = var._graph_element  # pylint: disable=protected-access
          else:
            var = ops.convert_to_tensor(var, as_ref=True)
          if not _tensor_comes_from_variable(var):
            raise TypeError(f"Variable to save is not a Variable: {var}")
        if var.op.type == "ReadVariableOp":
          name = var.op.inputs[0].op.name
        else:
          name = var.op.name
        if name in names_to_saveables:
          raise ValueError(f"At least two variables have the same name: {name}")
        names_to_saveables[name] = var

    # pylint: enable=protected-access
  return names_to_saveables


def _add_saveable(saveables, seen_ops, saveable):
  """Adds the saveable to the saveables list.

  Args:
    saveables: List to append the SaveableObject to.
    seen_ops: Set of the ops of the saveables already processed.  Used to
      check that each saveable is only saved once.
    saveable: The saveable.

  Raises:
    ValueError: If the saveable has already been processed.
  """
  # Saveables whose `op` is None are never treated as duplicates.
  if saveable.op is not None and saveable.op in seen_ops:
    raise ValueError("The same saveable will be restored with two names: "
                     f"{saveable.name}")
  saveables.append(saveable)
  seen_ops.add(saveable.op)


def validate_and_slice_inputs(names_to_saveables):
  """Returns the variables and names that will be used for a Saver.

  Args:
    names_to_saveables: A dict (k, v) where k is the name of an operation and
       v is an operation to save or a BaseSaverBuilder.Saver.

  Returns:
    A list of SaveableObjects.

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a trackable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  if not isinstance(names_to_saveables, dict):
    names_to_saveables = op_list_to_dict(names_to_saveables)

  saveables = []
  seen_ops = object_identity.ObjectIdentitySet()
  for name, op in sorted(names_to_saveables.items(),
                         # Avoid comparing ops, sort only by name.
                         key=lambda x: x[0]):
    for converted_saveable_object in saveable_objects_for_op(op, name):
      _add_saveable(saveables, seen_ops, converted_saveable_object)
  return saveables


def trace_save_restore_function_map(obj, factory_data_list):
  """Traces all save and restore functions in the provided factory list.

  Args:
    obj: `Trackable` object.
    factory_data_list: List of `_CheckpointFactoryData`.

  Returns:
    Dict mapping attribute names to tuples of concrete save/restore functions.
  """
  saveable_fns = {}

  for factory_data in factory_data_list:
    saveable_factory = factory_data.factory
    attribute_name = factory_data.name

    # If object revives as a resource (or TPU/Mirrored) variable,
    # there is no need to trace the save and restore functions.
    if (resource_variable_ops.is_resource_variable(obj) or
        resource_variable_ops.is_resource_variable(saveable_factory) or
        not callable(saveable_factory)):
      continue

    concrete_save, concrete_restore = (
        _trace_save_restore_functions(saveable_factory, obj))
    if not concrete_save:
      continue
    saveable_fns[attribute_name] = (concrete_save, concrete_restore)
  return saveable_fns


def _trace_save_restore_functions(saveable_factory, obj):
  """Traces save and restore functions.

  Args:
    saveable_factory: Callable accepting a `name` keyword argument and
      returning a `SaveableObject` or an iterable of them.
    obj: The object the factory belongs to; forwarded to
      `validate_saveables_for_saved_model`.

  Returns:
    A `(concrete_save, concrete_restore)` tuple, or `(None, None)` when
    validation leaves no saveables.
  """
  if is_factory_for_restored_saveable_object(saveable_factory):
    # The factory already carries traced functions; reuse them.
    return (saveable_factory.keywords["save_function"],
            saveable_factory.keywords["restore_function"])

  saveables = []  # Store the saveables in a data structure accessible to both
  # the save and restore functions.

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def save_fn(checkpoint_key):
    maybe_saveable = saveable_factory(name=checkpoint_key)
    if isinstance(maybe_saveable, saveable_object.SaveableObject):
      maybe_saveable = [maybe_saveable]
    saveables[:] = maybe_saveable

    # Return list of all SaveSpecs created by the factory.
    ret = []
    for saveable in saveables:
      for spec in saveable.specs:
        ret.append({"name": spec.name, "tensor": spec.tensor,
                    "slice_spec": spec.slice_spec})
    return ret

  concrete_save = save_fn.get_concrete_function()

  # The SaveableObjects are produced when `save_fn` is traced.
  saveables = validate_saveables_for_saved_model(saveables, obj)
  if not saveables:
    return None, None

  # Use the SaveSpecs to define the input signature of the restore function.
  restored_type_specs = []
  tensor_structure = []
  for saveable in saveables:
    saveable_tensor_structure = []
    tensor_structure.append(saveable_tensor_structure)
    for spec in saveable.specs:
      restored_type_specs.append(type_spec.type_spec_from_value(spec.tensor))
      saveable_tensor_structure.append(spec.name)

  @def_function.function(input_signature=restored_type_specs)
  def restore_fn(*restored_tensors):
    structured_restored_tensors = nest.pack_sequence_as(
        tensor_structure, restored_tensors)
    # A distinct loop name keeps the flat `restored_tensors` argument from
    # being shadowed by the per-saveable group.
    for saveable, saveable_tensors in zip(saveables,
                                          structured_restored_tensors):
      saveable.restore(saveable_tensors, restored_shapes=None)
    return 1  # Return dummy tensor

  concrete_restore = restore_fn.get_concrete_function()
  return concrete_save, concrete_restore


def validate_saveables_for_saved_model(saveables, obj):
  """Makes sure SaveableObjects are compatible with SavedModel.

  Args:
    saveables: List of `SaveableObject`s generated for `obj`.
    obj: The object the saveables were generated from.

  Returns:
    An empty list if `obj` is a `PythonState` (a warning is logged) or if any
    saveable is a `NoRestoreSaveable`; otherwise `saveables` unchanged.
  """
  if isinstance(obj, python_state.PythonState):
    logging.warning(
        "Note that object %s stores python values into the checkpoint. "
        "These values will not be restored when loading the SavedModel "
        "into python.", obj)
    return []
  if any(isinstance(saveable, trackable.NoRestoreSaveable)
         for saveable in saveables):
    return []
  return saveables


class RestoredSaveableObject(saveable_object.SaveableObject):
  """SaveableObject restored from SavedModel using the traced save/restore."""

  def __init__(self, names_and_slices, save_function, restore_function, name):
    """Builds `SaveSpec`s by calling `save_function` with `name`.

    Args:
      names_and_slices: Sequence of `(name_suffix, slice_spec)` string pairs,
        zipped positionally with the entries returned by `save_function`.
      save_function: Callable taking the name tensor and returning a sequence
        of dicts that each hold a `"tensor"` entry.
      restore_function: Callable invoked by `restore` with the restored
        tensors as positional arguments.
      name: The saveable name; a Python value is converted to a constant
        inside `ops.init_scope()`, a TF value is used as-is.
    """
    self.save_function = save_function
    self.restore_function = restore_function

    if tensor_util.is_tf_type(name):
      name_tensor = name
    else:
      with ops.init_scope():
        name_tensor = constant_op.constant(name)
    tensors = save_function(name_tensor)
    specs = []
    for (str_name, str_slice), tensor_info in zip(names_and_slices, tensors):
      specs.append(saveable_object.SaveSpec(tensor_info["tensor"], str_slice,
                                            name + str_name))
    super().__init__(None, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    """Calls the restore function with one restored tensor per spec."""
    del restored_shapes  # unused
    return self.restore_function(
        *[restored_tensors[i] for i in range(len(self.specs))])


def recreate_saveable_objects(saveable_fn_by_name):
  """Returns a dict of SaveableObject factories generated from loaded fns.

  Args:
    saveable_fn_by_name: Dict mapping a name to a
      `(save_function, restore_function)` pair.

  Returns:
    Dict mapping each name to a `functools.partial` of
    `RestoredSaveableObject` bound to that entry's functions and to the
    names/slice specs reported by that entry's save function.
  """
  saveable_factories = {}
  for name, (save_fn, restore_fn) in saveable_fn_by_name.items():
    # `RestoredSaveableObject` zips `names_and_slices` positionally with the
    # output of its own save function, so each factory must receive only the
    # entries produced by *its* save function. Sharing one list accumulated
    # over all functions would pair later factories with the first one's names.
    with ops.init_scope():
      names_and_slices = [
          (_convert_to_string(tensor_info["name"]),
           _convert_to_string(tensor_info["slice_spec"]))
          for tensor_info in save_fn("")]
    saveable_factories[name] = functools.partial(
        RestoredSaveableObject,
        names_and_slices=names_and_slices,
        save_function=save_fn,
        restore_function=restore_fn)
  return saveable_factories


def create_saveable_object(name, key, factory, call_with_mapped_captures):
  """Creates a SaveableObject while potentially in a different graph.

  When creating the frozen saver for SavedModel, the save and restore ops are
  placed in a separate graph. Since RestoredSaveableObject uses tf.functions to
  save and restore, the function captures must be mapped to the new graph.

  Args:
    name: Name of SaveableObject factory.
    key: Checkpoint key of this SaveableObject.
    factory: Factory method for creating the SaveableObject.
    call_with_mapped_captures: Helper that calls a tf.function while remapping
      the captures.

  Returns:
    a SaveableObject.
  """
  if call_with_mapped_captures is None:
    return factory(name=key)
  if name == trackable_utils.SERIALIZE_TO_TENSORS_NAME:
    return factory(name=key,
                   call_with_mapped_captures=call_with_mapped_captures)
  elif is_factory_for_restored_saveable_object(factory):
    concrete_save_fn = factory.keywords["save_function"]

    def save_fn(name):
      return call_with_mapped_captures(concrete_save_fn, [name])

    concrete_restore_fn = factory.keywords["restore_function"]

    def restore_fn(*restored_tensors):
      return call_with_mapped_captures(concrete_restore_fn, restored_tensors)

    return factory(save_function=save_fn, restore_function=restore_fn,
                   name=key)
  else:
    return factory(name=key)


def is_factory_for_restored_saveable_object(factory):
  """Returns True if `factory` is a partial of `RestoredSaveableObject`."""
  return (isinstance(factory, functools.partial) and
          factory.func is RestoredSaveableObject)


@tf_export("__internal__.tracking.saveable_objects_from_trackable", v1=[])
def saveable_objects_from_trackable(obj):
  """Returns SaveableObject factory dict from a Trackable.

  Args:
    obj: The object to generate factories for.

  Returns:
    A dict of factories: a single `"py_state"` entry for `PythonState`
    objects, a single `SERIALIZE_TO_TENSORS_NAME` entry for objects that
    override `_serialize_to_tensors`, and otherwise whatever
    `obj._gather_saveables_for_checkpoint()` returns.
  """
  if isinstance(obj, python_state.PythonState):
    return {
        "py_state":
            functools.partial(
                _PythonStringStateSaveable,
                state_callback=obj.serialize,
                restore_callback=obj.deserialize)
    }
  if trackable_has_serialize_to_tensor(obj):

    def create_saveable(name="", call_with_mapped_captures=None):
      return TrackableSaveable(obj, name, call_with_mapped_captures)

    return {trackable_utils.SERIALIZE_TO_TENSORS_NAME: create_saveable}
  else:
    return obj._gather_saveables_for_checkpoint()  # pylint: disable=protected-access


class TrackableSaveable(saveable_object.SaveableObject):
  """A SaveableObject that defines `Trackable` checkpointing steps."""

  def __init__(self, obj, name, call_with_mapped_captures=None):
    """Builds `SaveSpec`s from `obj._serialize_to_tensors()`.

    Args:
      obj: Object providing `_serialize_to_tensors` and
        `_restore_from_tensors`.
      name: Prefix prepended to each escaped local tensor name to form the
        spec name.
      call_with_mapped_captures: Optional helper used to invoke the
        save/restore functions when they are `ConcreteFunction`s.
    """
    self._trackable = obj
    self._call_with_mapped_captures = call_with_mapped_captures

    save_fn = obj._serialize_to_tensors  # pylint: disable=protected-access

    if (call_with_mapped_captures and
        isinstance(save_fn, core.ConcreteFunction)):
      tensor_dict = call_with_mapped_captures(save_fn, [])
    else:
      tensor_dict = save_fn()

    specs = []
    self._local_names = []
    self._prefix = saveable_compat.get_saveable_name(self._trackable) or ""
    for tensor_name, maybe_tensor in tensor_dict.items():
      self._local_names.append(tensor_name)
      spec_name = name + trackable_utils.escape_local_name(tensor_name)

      if not isinstance(maybe_tensor, dict):
        maybe_tensor = {"": maybe_tensor}

      # Create separate specs for each slice spec.
      # NOTE(review): a value with several slice specs yields several specs
      # but only one `_local_names` entry, so `restore` and
      # `get_proto_names_and_checkpoint_keys` (which index/zip the two lists
      # positionally) presumably expect one slice per tensor -- confirm.
      for slice_spec, tensor in maybe_tensor.items():
        specs.append(saveable_object.SaveSpec(tensor, slice_spec, spec_name))
    super().__init__(obj, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    """Passes the restored tensors to the object's `_restore_from_tensors`.

    Args:
      restored_tensors: Sequence of tensors, indexed in the same order as the
        local tensor names recorded at construction.
      restored_shapes: Unused.

    Returns:
      A dummy constant tensor; outside of eager execution the restore call is
      wrapped in a `tf.function` first.
    """
    del restored_shapes  # Unused.
    restored_tensor_dict = {}
    for n, local_name in enumerate(self._local_names):
      restored_tensor_dict[local_name] = restored_tensors[n]

    def restore_from_tensors():
      restore_fn = self._trackable._restore_from_tensors  # pylint: disable=protected-access
      if (self._call_with_mapped_captures and
          isinstance(restore_fn, core.ConcreteFunction)):
        self._call_with_mapped_captures(restore_fn, [restored_tensor_dict])
      else:
        restore_fn(restored_tensor_dict)

      # In graph mode, this wrapper function is converted into a tf.function,
      # and to ensure that _restore_from_tensors is executed, there must be at
      # least one returned tensor. `_restore_from_tensors` may return zero
      # tensors so create a dummy constant here.
      return constant_op.constant(1)

    if not ops.executing_eagerly_outside_functions():
      restore_from_tensors = def_function.function(restore_from_tensors)
    return restore_from_tensors()

  def get_proto_names_and_checkpoint_keys(self):
    """Returns `(prefix + local_name, spec.name)` pairs, zipped positionally."""
    return [(self._prefix + local_name, spec.name)
            for local_name, spec in zip(self._local_names, self.specs)]


class _PythonStringStateSaveable(saveable_object.SaveableObject):
  """Saves Python state in a checkpoint."""

  def __init__(self, name, state_callback, restore_callback):
    """Configure saving.

    Args:
      name: The checkpoint key to write to.
      state_callback: A function taking no arguments which returns a string.
        This function is run every time a checkpoint is written.
      restore_callback: A function taking a Python string, used to restore
        state.
    """

    def _state_callback_wrapper():
      with ops.init_scope():
        return state_callback()

    self._state_callback = _state_callback_wrapper
    self._restore_callback = restore_callback
    with ops.device("/cpu:0"):
      # Empty placeholder string; the real state is supplied via
      # `feed_dict_additions` or `freeze`.
      self._save_string = constant_op.constant("", dtype=dtypes.string)
    spec = saveable_object.SaveSpec(
        self._save_string, "", name, dtype=dtypes.string)
    super().__init__(self._save_string, [spec], name)

  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed."""
    return {self._save_string: self._state_callback()}

  def freeze(self):
    """Create a frozen `SaveableObject` which saves the current state."""

    def _constant_state():
      return constant_op.constant(self._state_callback(), dtype=dtypes.string)

    return trackable.NoRestoreSaveable(
        tensor=_constant_state,
        dtype=dtypes.string,
        name=self.name,
        device="cpu:0")


def trackable_has_serialize_to_tensor(obj):
  """Returns True if `obj` overrides `Trackable._serialize_to_tensors`."""
  # pylint: disable=protected-access
  obj_serialize_fn = obj._serialize_to_tensors
  # Unwrap bound methods so the comparison is made on the plain function.
  if hasattr(obj_serialize_fn, "__func__"):
    obj_serialize_fn = obj_serialize_fn.__func__
  return trackable.Trackable._serialize_to_tensors != obj_serialize_fn
  # pylint: enable=protected-access


def _convert_to_string(x):
  """Returns the constant value of `x` converted with `compat.as_str`."""
  return compat.as_str(tensor_util.constant_value(x))


class SaveableCompatibilityConverter(trackable.Trackable):
  """Converts object's `SaveableObjects` to functions used in TF2 checkpointing.

  A class that converts a Trackable object's `SaveableObjects` to save and
  restore functions with the same signatures as
  `Trackable._serialize_to_tensors` and `Trackable._restore_from_tensors`.
  This class also produces a method for filling the object proto.
  """

  __slots__ = ("_obj", "_cached_saveables")

  def __init__(self, obj):
    """Constructor.

    Args:
      obj: A Trackable object which implements the deprecated
        `_gather_saveables_for_checkpoint`.
    """
    self._obj = obj
    self._cached_saveables = None

    _ = self._saveables  # Generate cached saveables when converter is created.

  @property
  def _saveables(self):
    """Returns a list of SaveableObjects generated from the Trackable object."""
    if self._cached_saveables is not None:
      return self._cached_saveables

    self._cached_saveables = []
    saveable_names = []
    for name, saveable_factory in (
        saveable_objects_from_trackable(self._obj).items()):
      if callable(saveable_factory):
        maybe_saveable = create_saveable_object(
            name, name, saveable_factory, call_with_mapped_captures=None)
      else:
        maybe_saveable = saveable_factory
      if isinstance(maybe_saveable, saveable_object.SaveableObject):
        saveables = (maybe_saveable,)
      else:
        saveables = tuple(saveable_objects_for_op(op=maybe_saveable, name=name))
      self._cached_saveables.extend(saveables)
      saveable_names.extend([name] * len(saveables))

    if not saveable_compat.force_checkpoint_conversion_enabled():
      # Run an extra step to validate that the converter can be used without
      # changing the checkpoint metadata.
      self._maybe_apply_legacy_decorator(saveable_names)

    return self._cached_saveables

  def _maybe_apply_legacy_decorator(self, saveable_names):
    """Applies `legacy_saveable_name` when a saveable has differing spec names.

    Args:
      saveable_names: List holding, for each cached saveable, the factory name
        it was generated from.

    Raises:
      saveable_compat.CheckpointConversionError: If the decorator is needed
        but the saveables come from more than one factory name.
    """
    # Check the spec names. If there are multiple specs with different names
    # under the same saveable, then this indicates that a decorator must be
    # used to ensure checkpoint equality under the new checkpoint
    # implementation. See the docstring `legacy_saveable_name` for details.
    for saveable in self._cached_saveables:
      # Compare the names, not the spec objects themselves: distinct spec
      # objects that share one name do not require the decorator.
      spec_names = set(spec.name for spec in saveable.specs)

      if len(spec_names) == 1:
        continue  # Decorator not needed.

      if len(set(saveable_names)) > 1:
        # An edge case not handled by the legacy decorator has been encountered.
        raise saveable_compat.CheckpointConversionError

      saveable_compat.legacy_saveable_name(saveable_names[0])(self)

  def _serialize_to_tensors(self):
    """Returns a dict of tensors to serialize."""
    return saveable_object_to_tensor_dict(self._saveables)

  def _restore_from_tensors(self, restored_tensors):
    """Returns the restore ops defined in the Saveables."""
    # Map restored tensors to the corresponding SaveableObjects, then call
    # restore. There must be an exact match between restored tensors and the
    # expected attributes.
    expected_keys = []
    for saveable in self._saveables:
      expected_keys.extend(spec.name for spec in saveable.specs)
    if set(expected_keys) != restored_tensors.keys():
      raise ValueError(f"Could not restore object {self._obj} because not all "
                       "expected tensors were in the checkpoint."
                       f"\n\tExpected: {expected_keys}"
                       f"\n\tGot: {list(restored_tensors.keys())}")

    return saveable_object_to_restore_fn(self._saveables)(restored_tensors)


def saveable_object_to_tensor_dict(saveables):
  """Converts a list of SaveableObjects to a tensor dictionary.

  Args:
    saveables: Iterable of `SaveableObject`s.

  Returns:
    Dict keyed by spec name. The value is the tensor (or the `SaveSpec`
    itself when its tensor is callable) for specs without a slice spec, and a
    nested dict keyed by slice spec otherwise.
  """
  tensor_dict = {}
  for saveable in saveables:
    for spec in saveable.specs:
      name = _convert_to_string(spec.name)
      slice_spec = _convert_to_string(spec.slice_spec)
      # Currently, tensor dict cannot handle callable tensor values (which
      # are needed for uninitialized variables), so keep using SaveSpec.
      tensor = spec if callable(spec._tensor) else spec._tensor  # pylint: disable=protected-access
      if slice_spec:
        tensor_dict.setdefault(name, {})[slice_spec] = tensor
      else:
        tensor_dict[name] = tensor
  return tensor_dict


def saveable_object_to_restore_fn(saveables):
  """Generates `Trackable._restore_from_tensors` from SaveableObjects.

  Args:
    saveables: Iterable of `SaveableObject`s to restore.

  Returns:
    A function taking a dict of restored tensors (keyed by spec name, with an
    optional nested dict keyed by slice spec) and returning a dict mapping
    each saveable's name to the result of its `restore` call.
  """

  def _restore_from_tensors(restored_tensors):
    restore_ops = {}

    for saveable in saveables:
      saveable_restored_tensors = []
      for spec in saveable.specs:
        name = _convert_to_string(spec.name)
        slice_spec = _convert_to_string(spec.slice_spec)

        maybe_tensor = restored_tensors[name]
        # Normalize un-sliced values so both forms can be indexed by slice.
        if not isinstance(maybe_tensor, dict):
          maybe_tensor = {"": maybe_tensor}

        saveable_restored_tensors.append(maybe_tensor[slice_spec])
      restore_ops[saveable.name] = saveable.restore(
          saveable_restored_tensors, restored_shapes=None)
    return restore_ops

  return _restore_from_tensors