• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for working with and creating SaveableObjects."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import six
21
22from tensorflow.python.eager import context
23from tensorflow.python.framework import device as pydev
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import resource_variable_ops
27from tensorflow.python.ops import state_ops
28from tensorflow.python.ops import variables
29from tensorflow.python.training.saving import saveable_object
30from tensorflow.python.training.tracking import base as trackable
31from tensorflow.python.util import nest
32from tensorflow.python.util import object_identity
33
34
# Types of ops whose output tensors stand for a variable (a reference
# variable, a resource handle, or a read of one); tensors produced by these
# ops are accepted as things to save.
_VARIABLE_OPS = {
    "Variable",
    "VariableV2",
    "AutoReloadVariable",
    "VarHandleOp",
    "ReadVariableOp",
}
41
42
def set_cpu0(device_string):
  """Rewrites `device_string` so that it names device /CPU:0.

  Only the device type and device index of the parsed spec are replaced;
  every other field of the parsed spec is kept as-is. A string that already
  names /CPU:0 therefore comes back equivalent to the input.

  Args:
    device_string: A device string.

  Returns:
    A device string.
  """
  return (pydev.DeviceSpec.from_string(device_string)
          .replace(device_type="CPU", device_index=0)
          .to_string())
57
58
class ReferenceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles reference variables.

  The variable itself is the tensor that gets saved; restoring assigns the
  loaded value into `self.op` with `state_ops.assign`.
  """

  def __init__(self, var, slice_spec, name):
    """Wraps `var` in a single SaveSpec.

    Args:
      var: The reference variable (or its ref tensor) to save.
      slice_spec: Slice specification string ("" when saving the whole
        variable).
      name: Name to save `var` under.
    """
    save_spec = saveable_object.SaveSpec(var, slice_spec, name,
                                         dtype=var.dtype)
    super(ReferenceVariableSaveable, self).__init__(var, [save_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Returns an op that assigns the restored value to the variable.

    Args:
      restored_tensors: List whose first element is the restored value.
      restored_shapes: None, or a list whose first element is the shape the
        restored value is reshaped to before assignment.

    Returns:
      The result of `state_ops.assign`.
    """
    value = restored_tensors[0]
    if restored_shapes is None:
      # Shape validation is only requested when no explicit reshape happened
      # and the variable's static shape is fully known.
      check_shape = self.op.get_shape().is_fully_defined()
    else:
      value = array_ops.reshape(value, restored_shapes[0])
      check_shape = False
    return state_ops.assign(self.op, value, validate_shape=check_shape)
75
76
class ResourceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles ResourceVariables.

  `var` may be a resource variable or a Tensor. For a Tensor, the first input
  of its producing op is used as the resource handle that `restore` assigns
  to.
  """

  def __init__(self, var, slice_spec, name):
    """Builds the single SaveSpec describing `var`.

    Args:
      var: A resource variable, or a Tensor read from one.
      slice_spec: Slice specification string ("" when saving the whole
        variable).
      name: Name to save `var` under.

    Raises:
      ValueError: If `var` is neither a Tensor nor a resource variable.
    """
    self._var_device = var.device
    self._var_shape = var.shape
    if isinstance(var, ops.Tensor):
      self.handle_op = var.op.inputs[0]
      save_tensor = var
    elif resource_variable_ops.is_resource_variable(var):
      self.handle_op = var.handle

      def _read_on_cpu():
        # Read on the variable's own device, then copy the value to the CPU
        # of the same machine so variables placed on non-CPU devices can be
        # checkpointed.
        with ops.device(var.device):
          current_value = var.read_value()
          with ops.device("/device:CPU:0"):
            return array_ops.identity(current_value)

      # The SaveSpec receives the callable itself rather than a tensor.
      save_tensor = _read_on_cpu
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          " Got: %s" % repr(var))
    spec = saveable_object.SaveSpec(save_tensor, slice_spec, name,
                                    dtype=var.dtype, device=var.device)
    super(ResourceVariableSaveable, self).__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Returns an op that assigns the restored value to the variable handle.

    Args:
      restored_tensors: List whose first element is the restored value.
      restored_shapes: None, or a list whose first element is the shape the
        restored value is reshaped to before assignment.

    Returns:
      The result of `shape_safe_assign_variable_handle`.
    """
    value = restored_tensors[0]
    if restored_shapes is not None:
      value = array_ops.reshape(value, restored_shapes[0])
    with ops.device(self._var_device):
      # Move the restored value onto the variable's device before assigning.
      value_on_device = array_ops.identity(value)
      return resource_variable_ops.shape_safe_assign_variable_handle(
          self.handle_op, self._var_shape, value_on_device)
117
118
def _tensor_comes_from_variable(v):
  """Returns True iff `v` is a Tensor produced by an op in `_VARIABLE_OPS`."""
  if not isinstance(v, ops.Tensor):
    return False
  return v.op.type in _VARIABLE_OPS
121
122
def _saveable_for_variable(variable, slice_spec, name):
  """Wraps `variable` in the SaveableObject class matching its op type.

  Reference-variable op types get a ReferenceVariableSaveable; every other
  op type is handled as a resource variable.
  """
  if variable.op.type in ("Variable", "VariableV2", "AutoReloadVariable"):
    return ReferenceVariableSaveable(variable, slice_spec, name)
  return ResourceVariableSaveable(variable, slice_spec, name)


def saveable_objects_for_op(op, name):
  """Create `SaveableObject`s from an operation.

  Args:
    op: A variable, operation, or SaveableObject to coerce into a
      SaveableObject. A list/tuple/PartitionedVariable is treated as a set of
      slices that must all belong to the same full variable; a Trackable
      (other than a Variable) is expanded via its saveable factories.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string, or a non-variable tensor is given.
    ValueError: For operations with no known conversion to SaveableObject.
  """
  if not isinstance(name, six.string_types):
    # Wrap in a 1-tuple: a tuple `name` would otherwise be unpacked by `%`
    # as several format arguments and mask this error.
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        "trackable operations. Name is not a string: %s" % (name,))
  if isinstance(op, saveable_object.SaveableObject):
    yield op
  elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
    if isinstance(op, variables.PartitionedVariable):
      op = list(op)
    # A set of slices; all of them must come from the same full variable.
    slice_name = None
    # pylint: disable=protected-access
    for variable in op:
      if isinstance(variable, saveable_object.SaveableObject):
        yield variable
        continue
      if not isinstance(variable, variables.Variable):
        # `variable` may itself be a tuple (nested input); the 1-tuple keeps
        # `%` from unpacking it and raising TypeError instead of ValueError.
        raise ValueError("Slices must all be Variables: %s" % (variable,))
      save_slice_info = variable._save_slice_info
      if not save_slice_info:
        raise ValueError("Slices must all be slices: %s" % (variable,))
      if slice_name is None:
        slice_name = save_slice_info.full_name
      elif slice_name != save_slice_info.full_name:
        raise ValueError(
            "Slices must all be from the same tensor: %s != %s" %
            (slice_name, save_slice_info.full_name))
      yield _saveable_for_variable(variable, save_slice_info.spec, name)
    # pylint: enable=protected-access
  elif isinstance(op, trackable.Trackable) and not isinstance(
      op, variables.Variable):
    # pylint: disable=protected-access
    for attr, factory in op._gather_saveables_for_checkpoint().items():
      if attr == trackable.VARIABLE_VALUE_KEY:
        # Keep original name for classes masquerading as variables.
        full_name = name
      else:
        full_name = name + "_" + attr
      # Distinct names so the argument `op` is not rebound inside its own
      # iteration.
      attr_saveable = factory(full_name) if callable(factory) else factory
      for saveable in saveable_objects_for_op(attr_saveable,
                                              attr_saveable.name):
        yield saveable
    # pylint: enable=protected-access
  else:
    # A variable or tensor.
    if isinstance(op, resource_variable_ops.BaseResourceVariable):
      # pylint: disable=protected-access
      if op._in_graph_mode:
        variable = op._graph_element
      else:
        variable = op
      # pylint: enable=protected-access
      yield ResourceVariableSaveable(variable, "", name)
    else:
      if context.executing_eagerly():
        raise ValueError("Can only save/restore ResourceVariables when "
                         "executing eagerly, got type: %s." % type(op))

      variable = ops.convert_to_tensor(op, as_ref=True)
      if not _tensor_comes_from_variable(variable):
        raise TypeError("names_to_saveables must be a dict mapping string "
                        "names to Tensors/Variables. Not a variable: %s" %
                        (variable,))
      yield _saveable_for_variable(variable, "", name)
211
212
def op_list_to_dict(op_list, convert_variable_to_tensor=True):
  """Create a dictionary of names to operation lists.

  Args:
    op_list: A (nested) list, tuple, or set of Variables or SaveableObjects.
    convert_variable_to_tensor: Whether or not to convert single Variables
      with no slice info into Tensors.

  Returns:
    A dictionary of names to the operations that must be saved under
    that name.  Variables with save_slice_info are grouped together under the
    same key in no particular order.

  Raises:
    TypeError: If the type of op_list or its elements is not supported.
    ValueError: If at least two saveables share the same name.
  """
  if not isinstance(op_list, (list, tuple, set)):
    raise TypeError("Variables to save should be passed in a dict or a "
                    "list: %s" % op_list)
  # List casting is necessary to support sets.
  op_list = nest.flatten(list(op_list))
  # When ResourceVariables are converted to Tensors, read ops are added to the
  # graph. Sorting the op_list ensures that the resulting graph is always
  # constructed in a deterministic way:
  op_list = sorted(op_list, key=lambda x: x.name)
  names_to_saveables = {}
  # pylint: disable=protected-access
  for var in op_list:
    resource_or_ref_variable = (
        isinstance(var, resource_variable_ops.BaseResourceVariable) or
        isinstance(var, variables.RefVariable))

    if isinstance(var, saveable_object.SaveableObject):
      # Listing the very same object twice is harmless, but a *different*
      # entry under an existing name used to silently replace it, dropping
      # one saveable from the checkpoint. Reject that, as documented.
      existing = names_to_saveables.setdefault(var.name, var)
      if existing is not var:
        raise ValueError("At least two saveables have the same name: %s" %
                         var.name)
    elif isinstance(var, variables.PartitionedVariable):
      if var.name in names_to_saveables:
        raise ValueError("At least two variables have the same name: %s" %
                         var.name)
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.Variable) and var._save_slice_info:
      # Slices are grouped in a list under the full (unsliced) name.
      name = var._save_slice_info.full_name
      if name in names_to_saveables:
        if not isinstance(names_to_saveables[name], list):
          raise ValueError("Mixing slices and non-slices with the same name: "
                           "%s" % name)
        names_to_saveables[name].append(var)
      else:
        names_to_saveables[name] = [var]
    elif isinstance(var, trackable.Trackable) and not resource_or_ref_variable:
      # Expand the trackable into the saveables its factories produce.
      trackable_saveables = [
          (factory() if callable(factory) else factory)
          for factory in var._gather_saveables_for_checkpoint().values()]
      names_to_saveables.update(
          op_list_to_dict(trackable_saveables))
    else:
      # Variables (reference and resource) have an _in_graph_mode property
      # indicating whether they were created in a graph building context. We
      # also get Tensors when graph building, which do not have this property.
      if not getattr(var, "_in_graph_mode", True):
        if not isinstance(var, resource_variable_ops.BaseResourceVariable):
          raise ValueError(
              "Can only save/restore ResourceVariables when eager execution "
              "is enabled, type: %s." % type(var))
        set_var = names_to_saveables.setdefault(var._shared_name, var)
        if set_var is not var:
          raise ValueError(
              ("Two different ResourceVariable objects with the same "
               "shared_name '%s' were passed to the Saver. This likely means "
               "that they were created in different Graphs or isolation "
               "contexts, and may not be checkpointed together.") %
              (var._shared_name,))
      else:
        if convert_variable_to_tensor:
          if isinstance(var, resource_variable_ops.BaseResourceVariable):
            var = var._graph_element  # pylint: disable=protected-access
          else:
            var = ops.convert_to_tensor(var, as_ref=True)
          if not _tensor_comes_from_variable(var):
            raise TypeError("Variable to save is not a Variable: %s" % var)
        # A read of a resource variable is keyed by the handle op's name.
        if var.op.type == "ReadVariableOp":
          name = var.op.inputs[0].op.name
        else:
          name = var.op.name
        if name in names_to_saveables:
          raise ValueError("At least two variables have the same name: %s" %
                           name)
        names_to_saveables[name] = var

    # pylint: enable=protected-access
  return names_to_saveables
304
305
def _add_saveable(saveables, seen_ops, saveable):
  """Appends `saveable` to `saveables` unless its op was already added.

  Args:
    saveables: List the SaveableObject is appended to.
    seen_ops: Set of the `op`s of saveables processed so far; `saveable.op`
      is added to it on success. Ensures each op is saved only once.
    saveable: The SaveableObject to add.

  Raises:
    ValueError: If `saveable.op` is already in `seen_ops`.
  """
  target_op = saveable.op
  if target_op not in seen_ops:
    saveables.append(saveable)
    seen_ops.add(target_op)
    return
  raise ValueError("The same saveable will be restored with two names: %s" %
                   saveable.name)
323
324
def validate_and_slice_inputs(names_to_saveables):
  """Returns the SaveableObjects that will be used for a Saver.

  Args:
    names_to_saveables: Either a dict mapping names to the objects to save
      under them (Tensors, Variables, slice lists, trackables, or
      SaveableObjects), or a list/tuple/set of such objects, which is first
      converted to a dict with `op_list_to_dict`.

  Returns:
    A list of SaveableObjects, in order of their dict key (name).

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a trackable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  if isinstance(names_to_saveables, dict):
    mapping = names_to_saveables
  else:
    mapping = op_list_to_dict(names_to_saveables)

  result = []
  already_seen = object_identity.ObjectIdentitySet()
  # Order by key only; the values may not support comparison.
  ordered_items = sorted(mapping.items(), key=lambda pair: pair[0])
  for key, value in ordered_items:
    for converted in saveable_objects_for_op(value, key):
      _add_saveable(result, already_seen, converted)
  return result
352