# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for working with and creating SaveableObjects."""
import functools

from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import base as trackable
from tensorflow.python.trackable import python_state
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.types import core
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export

# Op names which identify variable reads which should be saved.
_VARIABLE_OPS = set(["Variable",
                     "VariableV2",
                     "AutoReloadVariable",
                     "VarHandleOp",
                     "ReadVariableOp"])


def set_cpu0(device_string):
  """Creates a new device string based on `device_string` but using /CPU:0.

  If the device is already on /CPU:0, this is a no-op.

  Args:
    device_string: A device string.

  Returns:
    A device string.
  """
  parsed_device = pydev.DeviceSpec.from_string(device_string)
  parsed_device = parsed_device.replace(device_type="CPU", device_index=0)
  return parsed_device.to_string()

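
# Illustrative example (not part of the original module): `set_cpu0` rewrites
# only the device component of the string and keeps job/replica/task intact.
#
#   set_cpu0("/job:worker/replica:0/task:0/device:GPU:1")
#   # -> "/job:worker/replica:0/task:0/device:CPU:0"
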

class ReferenceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles reference variables."""

  def __init__(self, var, slice_spec, name):
    spec = saveable_object.SaveSpec(var, slice_spec, name, dtype=var.dtype)
    super(ReferenceVariableSaveable, self).__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    restored_tensor = restored_tensors[0]
    if restored_shapes is not None:
      restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
    return state_ops.assign(
        self.op,
        restored_tensor,
        validate_shape=restored_shapes is None and
        self.op.get_shape().is_fully_defined())


class ResourceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles ResourceVariables."""

  def __init__(self, var, slice_spec, name):
    self._var_device = var.device
    self._var_shape = var.shape
    if isinstance(var, ops.Tensor):
      self.handle_op = var.op.inputs[0]
      tensor = var
    elif resource_variable_ops.is_resource_variable(var):

      def _read_variable_closure(v):
        def f():
          with ops.device(v.device):
            if context.executing_eagerly() and not v.is_initialized():
              # A SaveSpec tensor value of `None` indicates that the variable is
              # uninitialized.
              return None
            # Read the variable without making a copy to limit memory usage.
            x = v.read_value_no_copy()
            # To allow variables placed on non-CPU devices to be checkpointed,
            # we copy them to CPU on the same machine first.
            with ops.device("/device:CPU:0"):
              return array_ops.identity(x)

        return f

      self.handle_op = var.handle
      tensor = _read_variable_closure(var)
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          f" Got: {repr(var)}")
    spec = saveable_object.SaveSpec(tensor, slice_spec, name,
                                    dtype=var.dtype, device=var.device)
    super(ResourceVariableSaveable, self).__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Restores tensors. Raises ValueError if incompatible shape found."""
    restored_tensor = restored_tensors[0]
    if restored_shapes is not None:
      restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
    # Copy the restored tensor to the variable's device.
    with ops.device(self._var_device):
      restored_tensor = array_ops.identity(restored_tensor)
      try:
        assigned_variable = resource_variable_ops.shape_safe_assign_variable_handle(
            self.handle_op, self._var_shape, restored_tensor)
      except ValueError as e:
        raise ValueError(
            f"Received incompatible tensor with shape {restored_tensor.shape} "
            f"when attempting to restore variable with shape {self._var_shape} "
            f"and name {self.name}.") from e
      return assigned_variable

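
# Illustrative example (not part of the original module): wrapping an eagerly
# created variable `v` (a hypothetical `tf.Variable`) yields a saveable with a
# single SaveSpec whose tensor is the read closure defined above, so the value
# is copied to CPU only when the checkpoint is actually written.
#
#   saveable = ResourceVariableSaveable(v, "", "v")
#   saveable.restore([restored_value], restored_shapes=None)  # assigns to v
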

def _tensor_comes_from_variable(v):
  return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS


def saveable_objects_for_op(op, name):
  """Create `SaveableObject`s from an operation.

  Args:
    op: A variable, operation, or SaveableObject to coerce into a
      SaveableObject.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string.
    ValueError: For operations with no known conversion to SaveableObject.
  """
  if not isinstance(name, str):
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        f"trackable operations. Name is not a string: {name}")
  if isinstance(op, saveable_object.SaveableObject):
    yield op
  elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
    if isinstance(op, variables.PartitionedVariable):
      op = list(op)
    # A set of slices.
    slice_name = None
    # pylint: disable=protected-access
    for variable in op:
      if isinstance(variable, saveable_object.SaveableObject):
        yield variable
        continue
      if not isinstance(variable, variables.Variable):
        raise ValueError(f"Slices must all be Variables: {variable}")
      if not variable._save_slice_info:
        raise ValueError(f"Slices must all be slices: {variable}")
      if slice_name is None:
        slice_name = variable._save_slice_info.full_name
      elif slice_name != variable._save_slice_info.full_name:
        raise ValueError(
            f"Slices must all be from the same tensor: {slice_name} != "
            f"{variable._save_slice_info.full_name}")
      if variable.op.type in ["Variable", "VariableV2",
                              "AutoReloadVariable"]:
        yield ReferenceVariableSaveable(
            variable, variable._save_slice_info.spec, name)
      else:
        yield ResourceVariableSaveable(variable, variable._save_slice_info.spec,
                                       name)
    # pylint: enable=protected-access
  elif isinstance(op, trackable.Trackable) and not isinstance(
      op, variables.Variable):
    # pylint: disable=protected-access
    for attr, factory in saveable_objects_from_trackable(op).items():
      if attr == trackable.VARIABLE_VALUE_KEY:
        # Keep original name for classes masquerading as variables.
        full_name = name
      else:
        full_name = name + "_" + attr
      op = (factory(full_name) if callable(factory) else factory)
      for op in saveable_objects_for_op(op, op.name):
        yield op
    # pylint: enable=protected-access
  else:
    # A variable or tensor.
    if isinstance(op, resource_variable_ops.BaseResourceVariable):
      if op._in_graph_mode:  # pylint: disable=protected-access
        variable = op._graph_element  # pylint: disable=protected-access
      else:
        variable = op
      yield ResourceVariableSaveable(variable, "", name)
    else:
      if context.executing_eagerly():
        raise ValueError("Can only save/restore ResourceVariables when "
                         f"executing eagerly, got type: {type(op)}.")

      variable = ops.convert_to_tensor(op, as_ref=True)
      if not _tensor_comes_from_variable(variable):
        raise TypeError(
            "names_to_saveables must be a dict mapping string "
            f"names to Tensors/Variables. Not a variable: {variable}")
      if variable.op.type in ["Variable", "VariableV2",
                              "AutoReloadVariable"]:
        yield ReferenceVariableSaveable(variable, "", name)
      else:
        yield ResourceVariableSaveable(variable, "", name)

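
# Illustrative example (not part of the original module): for a single eager
# resource variable `v` (a hypothetical `tf.Variable`), the generator yields
# one ResourceVariableSaveable named after the given checkpoint name.
#
#   saveables = list(saveable_objects_for_op(v, "v"))
#   # saveables == [ResourceVariableSaveable] and saveables[0].name == "v"
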

def op_list_to_dict(op_list, convert_variable_to_tensor=True):
  """Create a dictionary of names to operation lists.

  Args:
    op_list: A (nested) list, tuple, or set of Variables or SaveableObjects.
    convert_variable_to_tensor: Whether or not to convert single Variables
      with no slice info into Tensors.

  Returns:
    A dictionary of names to the operations that must be saved under
    that name.  Variables with save_slice_info are grouped together under the
    same key in no particular order.

  Raises:
    TypeError: If the type of op_list or its elements is not supported.
    ValueError: If at least two saveables share the same name.
  """
  if not isinstance(op_list, (list, tuple, set)):
    raise TypeError("Variables to save should be passed in a dict or a "
                    f"list. Got {op_list}")
  # List casting is necessary to support sets.
  op_list = nest.flatten(list(op_list))
  # When ResourceVariables are converted to Tensors, read ops are added to the
  # graph. Sorting the op_list ensures that the resulting graph is always
  # constructed in a deterministic way:
  op_list = sorted(op_list, key=lambda x: x.name)
  names_to_saveables = {}
  # pylint: disable=protected-access
  for var in op_list:
    resource_or_ref_variable = (
        isinstance(var, resource_variable_ops.BaseResourceVariable) or
        isinstance(var, variables.RefVariable))

    if isinstance(var, saveable_object.SaveableObject):
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.PartitionedVariable):
      if var.name in names_to_saveables:
        raise ValueError(
            f"At least two variables have the same name: {var.name}")
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.Variable) and var._save_slice_info:
      name = var._save_slice_info.full_name
      if name in names_to_saveables:
        if not isinstance(names_to_saveables[name], list):
          raise ValueError("Mixing slices and non-slices with the same name: "
                           f"{name}")
        names_to_saveables[name].append(var)
      else:
        names_to_saveables[name] = [var]
    elif isinstance(var, trackable.Trackable) and not resource_or_ref_variable:
      trackable_saveables = [
          (factory() if callable(factory) else factory)
          for factory in saveable_objects_from_trackable(var).values()]
      names_to_saveables.update(
          op_list_to_dict(trackable_saveables))
    else:
      # Variables (reference and resource) have an _in_graph_mode property
      # indicating whether they were created in a graph building context. We
      # also get Tensors when graph building, which do not have this property.
      if not getattr(var, "_in_graph_mode", True):
        if not isinstance(var, resource_variable_ops.BaseResourceVariable):
          raise ValueError(
              "Can only save/restore ResourceVariables when eager execution "
              f"is enabled. Got type: {type(var)}.")
        set_var = names_to_saveables.setdefault(var._shared_name, var)
        if set_var is not var:
          raise ValueError(
              "Two different ResourceVariable objects with the same "
              f"shared_name '{var._shared_name}' were passed to the Saver. This"
              " likely means that they were created in different Graphs or "
              "isolated contexts, and may not be checkpointed together.")
      else:
        if convert_variable_to_tensor:
          if isinstance(var, resource_variable_ops.BaseResourceVariable):
            var = var._graph_element  # pylint: disable=protected-access
          else:
            var = ops.convert_to_tensor(var, as_ref=True)
          if not _tensor_comes_from_variable(var):
            raise TypeError(f"Variable to save is not a Variable: {var}")
        if var.op.type == "ReadVariableOp":
          name = var.op.inputs[0].op.name
        else:
          name = var.op.name
        if name in names_to_saveables:
          raise ValueError(f"At least two variables have the same name: {name}")
        names_to_saveables[name] = var

    # pylint: enable=protected-access
  return names_to_saveables

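
# Illustrative example (not part of the original module): for eager resource
# variables (hypothetical `tf.Variable`s named "v" and "w"), the result is
# keyed by each variable's shared name rather than by a graph op name.
#
#   op_list_to_dict([w, v])
#   # -> {"v": v, "w": w}
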

def _add_saveable(saveables, seen_ops, saveable):
  """Adds the saveable to the saveables list.

  Args:
    saveables: List to append the SaveableObject to.
    seen_ops: Set of the ops of the saveables already processed.  Used to
      check that each saveable is only saved once.
    saveable: The saveable.

  Raises:
    ValueError: If the saveable has already been processed.
  """
  if saveable.op is not None and saveable.op in seen_ops:
    raise ValueError("The same saveable will be restored with two names: "
                     f"{saveable.name}")
  saveables.append(saveable)
  seen_ops.add(saveable.op)


def validate_and_slice_inputs(names_to_saveables):
  """Returns the variables and names that will be used for a Saver.

  Args:
    names_to_saveables: A dict (k, v) where k is the name of an operation and
       v is an operation to save or a BaseSaverBuilder.Saver.

  Returns:
    A list of SaveableObjects.

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a trackable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  if not isinstance(names_to_saveables, dict):
    names_to_saveables = op_list_to_dict(names_to_saveables)

  saveables = []
  seen_ops = object_identity.ObjectIdentitySet()
  for name, op in sorted(names_to_saveables.items(),
                         # Avoid comparing ops, sort only by name.
                         key=lambda x: x[0]):
    for converted_saveable_object in saveable_objects_for_op(op, name):
      _add_saveable(saveables, seen_ops, converted_saveable_object)
  return saveables

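
# Illustrative example (not part of the original module): a plain list is first
# converted with `op_list_to_dict`, then every entry is expanded into
# SaveableObjects; passing the same variable `v` under two names raises.
#
#   validate_and_slice_inputs([v])               # -> [ResourceVariableSaveable]
#   validate_and_slice_inputs({"a": v, "b": v})  # raises ValueError
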

def trace_save_restore_function_map(obj, factory_data_list):
  """Traces all save and restore functions in the provided factory list.

  Args:
    obj: `Trackable` object.
    factory_data_list: List of `_CheckpointFactoryData`.

  Returns:
    Dict mapping attribute names to tuples of concrete save/restore functions.
  """
  saveable_fns = {}

  for factory_data in factory_data_list:
    saveable_factory = factory_data.factory
    attribute_name = factory_data.name

    # If object revives as a resource (or TPU/Mirrored) variable,
    # there is no need to trace the save and restore functions.
    if (resource_variable_ops.is_resource_variable(obj) or
        resource_variable_ops.is_resource_variable(saveable_factory) or
        not callable(saveable_factory)):
      continue

    concrete_save, concrete_restore = (
        _trace_save_restore_functions(saveable_factory, obj))
    if not concrete_save:
      continue
    saveable_fns[attribute_name] = (concrete_save, concrete_restore)
  return saveable_fns


def _trace_save_restore_functions(saveable_factory, obj):
  """Traces save and restore functions."""
  if is_factory_for_restored_saveable_object(saveable_factory):
    return (saveable_factory.keywords["save_function"],
            saveable_factory.keywords["restore_function"])

  saveables = []  # Store the saveables in a data structure accessible to both
  # the save and restore functions.

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def save_fn(checkpoint_key):
    maybe_saveable = saveable_factory(name=checkpoint_key)
    if isinstance(maybe_saveable, saveable_object.SaveableObject):
      maybe_saveable = [maybe_saveable]
    saveables[:] = maybe_saveable

    # Return list of all SaveSpecs created by the factory.
    ret = []
    for saveable in saveables:
      for spec in saveable.specs:
        ret.append({"name": spec.name, "tensor": spec.tensor,
                    "slice_spec": spec.slice_spec})
    return ret

  concrete_save = save_fn.get_concrete_function()

  # The SaveableObjects are produced when `save_fn` is traced.
  saveables = validate_saveables_for_saved_model(saveables, obj)
  if not saveables:
    return None, None

  # Use the SaveSpecs to define the input signature of the restore function.
  restored_type_specs = []
  tensor_structure = []
  for saveable in saveables:
    saveable_tensor_structure = []
    tensor_structure.append(saveable_tensor_structure)
    for spec in saveable.specs:
      restored_type_specs.append(type_spec.type_spec_from_value(spec.tensor))
      saveable_tensor_structure.append(spec.name)

  @def_function.function(input_signature=restored_type_specs)
  def restore_fn(*restored_tensors):
    structured_restored_tensors = nest.pack_sequence_as(
        tensor_structure, restored_tensors)
    for saveable, restored_tensors in zip(saveables,
                                          structured_restored_tensors):
      saveable.restore(restored_tensors, restored_shapes=None)
    return 1  # Return dummy tensor

  concrete_restore = restore_fn.get_concrete_function()
  return concrete_save, concrete_restore


def validate_saveables_for_saved_model(saveables, obj):
  """Makes sure SaveableObjects are compatible with SavedModel."""
  if isinstance(obj, python_state.PythonState):
    logging.warn(
        f"Note that object {obj} stores python values into the checkpoint. "
        "These values will not be restored when loading the SavedModel "
        "into python.")
    return []
  if any(isinstance(saveable, trackable.NoRestoreSaveable)
         for saveable in saveables):
    return []
  return saveables


class RestoredSaveableObject(saveable_object.SaveableObject):
  """SaveableObject restored from SavedModel using the traced save/restore."""

  def __init__(self, names_and_slices, save_function, restore_function, name):
    self.save_function = save_function
    self.restore_function = restore_function

    if tensor_util.is_tf_type(name):
      name_tensor = name
    else:
      with ops.init_scope():
        name_tensor = constant_op.constant(name)
    tensors = save_function(name_tensor)
    specs = []
    for (str_name, str_slice), tensor_info in zip(names_and_slices, tensors):
      specs.append(saveable_object.SaveSpec(tensor_info["tensor"], str_slice,
                                            name + str_name))
    super(RestoredSaveableObject, self).__init__(None, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    del restored_shapes  # unused
    return self.restore_function(
        *[restored_tensors[i] for i in range(len(self.specs))])


def recreate_saveable_objects(saveable_fn_by_name):
  """Returns a dict of SaveableObject factories generated from loaded fns."""

  names_and_slices = []

  with ops.init_scope():
    for save_fn, _ in saveable_fn_by_name.values():
      for tensor_info in save_fn(""):
        names_and_slices.append((
            _convert_to_string(tensor_info["name"]),
            _convert_to_string(tensor_info["slice_spec"])))

  saveable_factories = {}
  for name, (save_fn, restore_fn) in saveable_fn_by_name.items():
    saveable_factories[name] = functools.partial(
        RestoredSaveableObject,
        names_and_slices=names_and_slices,
        save_function=save_fn,
        restore_function=restore_fn)
  return saveable_factories


def create_saveable_object(name, key, factory, call_with_mapped_captures):
  """Creates a SaveableObject while potentially in a different graph.

  When creating the frozen saver for SavedModel, the save and restore ops are
  placed in a separate graph. Since RestoredSaveableObject uses tf.functions to
  save and restore, the function captures must be mapped to the new graph.

  Args:
    name: Name of SaveableObject factory.
    key: Checkpoint key of this SaveableObject.
    factory: Factory method for creating the SaveableObject.
    call_with_mapped_captures: Helper that calls a tf.function while remapping
      the captures.

  Returns:
    a SaveableObject.
  """
  if call_with_mapped_captures is None:
    return factory(name=key)
  if name == trackable_utils.SERIALIZE_TO_TENSORS_NAME:
    return factory(name=key,
                   call_with_mapped_captures=call_with_mapped_captures)
  elif is_factory_for_restored_saveable_object(factory):
    concrete_save_fn = factory.keywords["save_function"]

    def save_fn(name):
      return call_with_mapped_captures(concrete_save_fn, [name])

    concrete_restore_fn = factory.keywords["restore_function"]

    def restore_fn(*restored_tensors):
      return call_with_mapped_captures(concrete_restore_fn, restored_tensors)

    return factory(save_function=save_fn, restore_function=restore_fn,
                   name=key)
  else:
    return factory(name=key)


def is_factory_for_restored_saveable_object(factory):
  return (isinstance(factory, functools.partial) and
          factory.func is RestoredSaveableObject)


@tf_export("__internal__.tracking.saveable_objects_from_trackable", v1=[])
def saveable_objects_from_trackable(obj):
  """Returns SaveableObject factory dict from a Trackable."""
  if isinstance(obj, python_state.PythonState):
    return {
        "py_state":
            functools.partial(
                _PythonStringStateSaveable,
                state_callback=obj.serialize,
                restore_callback=obj.deserialize)
    }
  if trackable_has_serialize_to_tensor(obj):

    def create_saveable(name="", call_with_mapped_captures=None):
      return TrackableSaveable(obj, name, call_with_mapped_captures)

    return {trackable_utils.SERIALIZE_TO_TENSORS_NAME: create_saveable}
  else:
    return obj._gather_saveables_for_checkpoint()  # pylint: disable=protected-access

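
# Illustrative example (not part of the original module): for a Trackable that
# overrides `_serialize_to_tensors` (the TF2 path), the factory dict has a
# single well-known key; objects that only implement the legacy
# `_gather_saveables_for_checkpoint` have that dict returned unchanged.
#
#   factories = saveable_objects_from_trackable(obj)  # obj: some Trackable
#   factory = factories[trackable_utils.SERIALIZE_TO_TENSORS_NAME]
#   saveable = factory(name="obj/")  # -> TrackableSaveable
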

class TrackableSaveable(saveable_object.SaveableObject):
  """A SaveableObject that defines `Trackable` checkpointing steps."""

  def __init__(self, obj, name, call_with_mapped_captures=None):
    self._trackable = obj
    self._call_with_mapped_captures = call_with_mapped_captures

    save_fn = obj._serialize_to_tensors  # pylint: disable=protected-access

    if (call_with_mapped_captures and
        isinstance(save_fn, core.ConcreteFunction)):
      tensor_dict = call_with_mapped_captures(save_fn, [])
    else:
      tensor_dict = save_fn()

    specs = []
    self._local_names = []
    self._prefix = saveable_compat.get_saveable_name(self._trackable) or ""
    for tensor_name, maybe_tensor in tensor_dict.items():
      self._local_names.append(tensor_name)
      spec_name = name + trackable_utils.escape_local_name(tensor_name)

      if not isinstance(maybe_tensor, dict):
        maybe_tensor = {"": maybe_tensor}

      # Create separate specs for each slice spec.
      for slice_spec, tensor in maybe_tensor.items():
        specs.append(saveable_object.SaveSpec(tensor, slice_spec, spec_name))
    super(TrackableSaveable, self).__init__(obj, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    del restored_shapes  # Unused.
    restored_tensor_dict = {}
    for n, local_name in enumerate(self._local_names):
      restored_tensor_dict[local_name] = restored_tensors[n]

    def restore_from_tensors():
      restore_fn = self._trackable._restore_from_tensors  # pylint: disable=protected-access
      if (self._call_with_mapped_captures and
          isinstance(restore_fn, core.ConcreteFunction)):
        self._call_with_mapped_captures(restore_fn, [restored_tensor_dict])
      else:
        restore_fn(restored_tensor_dict)

      # In graph mode, this wrapper function is converted into a tf.function,
      # and to ensure that _restore_from_tensors is executed, there must be at
      # least one returned tensor. `_restore_from_tensors` may return zero
      # tensors so create a dummy constant here.
      return constant_op.constant(1)

    if not ops.executing_eagerly_outside_functions():
      restore_from_tensors = def_function.function(restore_from_tensors)
    return restore_from_tensors()

  def get_proto_names_and_checkpoint_keys(self):
    return [(self._prefix + local_name, spec.name)
            for local_name, spec in zip(self._local_names, self.specs)]


class _PythonStringStateSaveable(saveable_object.SaveableObject):
  """Saves Python state in a checkpoint."""

  def __init__(self, name, state_callback, restore_callback):
    """Configure saving.

    Args:
      name: The checkpoint key to write to.
      state_callback: A function taking no arguments which returns a string.
        This function is run every time a checkpoint is written.
      restore_callback: A function taking a Python string, used to restore
        state.
    """

    def _state_callback_wrapper():
      with ops.init_scope():
        return state_callback()

    self._state_callback = _state_callback_wrapper
    self._restore_callback = restore_callback
    with ops.device("/cpu:0"):
      self._save_string = constant_op.constant("", dtype=dtypes.string)
    spec = saveable_object.SaveSpec(
        self._save_string, "", name, dtype=dtypes.string)
    super(_PythonStringStateSaveable, self).__init__(self._save_string, [spec],
                                                     name)

  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed."""
    return {self._save_string: self._state_callback()}

  def freeze(self):
    """Create a frozen `SaveableObject` which saves the current state."""

    def _constant_state():
      return constant_op.constant(self._state_callback(), dtype=dtypes.string)

    return trackable.NoRestoreSaveable(
        tensor=_constant_state,
        dtype=dtypes.string,
        name=self.name,
        device="cpu:0")


def trackable_has_serialize_to_tensor(obj):
  # pylint: disable=protected-access
  obj_serialize_fn = obj._serialize_to_tensors
  if hasattr(obj_serialize_fn, "__func__"):
    obj_serialize_fn = obj_serialize_fn.__func__
  return trackable.Trackable._serialize_to_tensors != obj_serialize_fn
  # pylint: enable=protected-access

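
# Illustrative example (not part of the original module): the check compares
# the object's `_serialize_to_tensors` against the base implementation, so only
# subclasses that override it take the TF2 serialization path.
#
#   class Legacy(trackable.Trackable):
#     def _gather_saveables_for_checkpoint(self):
#       return {}
#
#   trackable_has_serialize_to_tensor(Legacy())  # -> False
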

def _convert_to_string(x):
  return compat.as_str(tensor_util.constant_value(x))


class SaveableCompatibilityConverter(trackable.Trackable):
  """Converts object's `SaveableObjects` to functions used in TF2 checkpointing.

  A class that converts a Trackable object's `SaveableObjects` to save and
  restore functions with the same signatures as
  `Trackable._serialize_to_tensors` and `Trackable._restore_from_tensors`.
  This class also produces a method for filling the object proto.
  """

  __slots__ = ("_obj", "_cached_saveables")

  def __init__(self, obj):
    """Constructor.

    Args:
      obj: A Trackable object which implements the deprecated
        `_gather_saveables_for_checkpoint`.
    """
    self._obj = obj
    self._cached_saveables = None

    _ = self._saveables  # Generate cached saveables when converter is created.

  @property
  def _saveables(self):
    """Returns a list of SaveableObjects generated from the Trackable object."""
    if self._cached_saveables is not None:
      return self._cached_saveables

    self._cached_saveables = []
    saveable_names = []
    for name, saveable_factory in (
        saveable_objects_from_trackable(self._obj).items()):
      if callable(saveable_factory):
        maybe_saveable = create_saveable_object(
            name, name, saveable_factory, call_with_mapped_captures=None)
      else:
        maybe_saveable = saveable_factory
      if isinstance(maybe_saveable, saveable_object.SaveableObject):
        saveables = (maybe_saveable,)
      else:
        saveables = tuple(saveable_objects_for_op(op=maybe_saveable, name=name))
      self._cached_saveables.extend(saveables)
      saveable_names.extend([name] * len(saveables))

    if not saveable_compat.force_checkpoint_conversion_enabled():
      # Run an extra step to validate that the converter can be used without
      # changing the checkpoint metadata.
      self._maybe_apply_legacy_decorator(saveable_names)

    return self._cached_saveables

  def _maybe_apply_legacy_decorator(self, saveable_names):
    # Check the spec names. If there are multiple specs with different names
    # under the same saveable, then this indicates that a decorator must be
    # used to ensure checkpoint equality under the new checkpoint
    # implementation. See the `legacy_saveable_name` docstring for details.
    for saveable in self._cached_saveables:
      spec_names = set(spec.name for spec in saveable.specs)

      if len(spec_names) == 1:
        continue  # Decorator not needed.

      if len(set(saveable_names)) > 1:
        # An edge case not handled by the legacy decorator has been encountered.
        raise saveable_compat.CheckpointConversionError

      saveable_compat.legacy_saveable_name(saveable_names[0])(self)

  def _serialize_to_tensors(self):
    """Returns a dict of tensors to serialize."""
    return saveable_object_to_tensor_dict(self._saveables)

  def _restore_from_tensors(self, restored_tensors):
    """Returns the restore ops defined in the Saveables."""
    # Map restored tensors to the corresponding SaveableObjects, then call
    # restore. There must be an exact match between restored tensors and the
    # expected attributes.
    expected_keys = []
    for saveable in self._saveables:
      expected_keys.extend(spec.name for spec in saveable.specs)
    if set(expected_keys) != restored_tensors.keys():
      raise ValueError(f"Could not restore object {self._obj} because not all "
                       "expected tensors were in the checkpoint."
                       f"\n\tExpected: {expected_keys}"
                       f"\n\tGot: {list(restored_tensors.keys())}")

    return saveable_object_to_restore_fn(self._saveables)(restored_tensors)


def saveable_object_to_tensor_dict(saveables):
  """Converts a list of SaveableObjects to a tensor dictionary."""
  tensor_dict = {}
  for saveable in saveables:
    for spec in saveable.specs:
      name = _convert_to_string(spec.name)
      slice_spec = _convert_to_string(spec.slice_spec)
      # Currently, tensor dict cannot handle callable tensor values (which
      # are needed for uninitialized variables), so keep using SaveSpec.
      tensor = spec if callable(spec._tensor) else spec._tensor  # pylint: disable=protected-access
      if slice_spec:
        tensor_dict.setdefault(name, {})[slice_spec] = tensor
      else:
        tensor_dict[name] = tensor
  return tensor_dict

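
# Illustrative example (not part of the original module): the shape of the
# returned dictionary, given a list of SaveableObjects.
#
#   saveable_object_to_tensor_dict(saveables)
#   # unsliced spec -> {spec.name: tensor}
#   # sliced specs  -> {spec.name: {slice_spec: tensor, ...}}
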

def saveable_object_to_restore_fn(saveables):
  """Generates `Trackable._restore_from_tensors` from SaveableObjects."""

  def _restore_from_tensors(restored_tensors):
    restore_ops = {}

    for saveable in saveables:
      saveable_restored_tensors = []
      for spec in saveable.specs:
        name = _convert_to_string(spec.name)
        slice_spec = _convert_to_string(spec.slice_spec)

        maybe_tensor = restored_tensors[name]
        if not isinstance(maybe_tensor, dict):
          maybe_tensor = {"": maybe_tensor}

        saveable_restored_tensors.append(maybe_tensor[slice_spec])
      restore_ops[saveable.name] = saveable.restore(
          saveable_restored_tensors, restored_shapes=None)
    return restore_ops

  return _restore_from_tensors