# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for working with and creating SaveableObjects."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.eager import context
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity


# Op names which identify variable reads which should be saved.
_VARIABLE_OPS = set(["Variable",
                     "VariableV2",
                     "AutoReloadVariable",
                     "VarHandleOp",
                     "ReadVariableOp"])

# Subset of _VARIABLE_OPS naming reference (non-resource) variable ops. A
# variable/tensor whose op type is in this set is wrapped in a
# ReferenceVariableSaveable; any other variable op gets a
# ResourceVariableSaveable.
_REF_VARIABLE_OPS = frozenset(["Variable",
                               "VariableV2",
                               "AutoReloadVariable"])


def set_cpu0(device_string):
  """Creates a new device string based on `device_string` but using /CPU:0.

  If the device is already on /CPU:0, this is a no-op.

  Args:
    device_string: A device string.

  Returns:
    A device string.
  """
  parsed_device = pydev.DeviceSpec.from_string(device_string)
  parsed_device = parsed_device.replace(device_type="CPU", device_index=0)
  return parsed_device.to_string()


class ReferenceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles reference variables."""

  def __init__(self, var, slice_spec, name):
    spec = saveable_object.SaveSpec(var, slice_spec, name, dtype=var.dtype)
    super(ReferenceVariableSaveable, self).__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the restored value to the wrapped reference variable.

    Args:
      restored_tensors: Sequence of restored tensors; only the first element
        is used.
      restored_shapes: `None`, or a sequence whose first element is the shape
        the restored tensor is reshaped to before assignment.

    Returns:
      The result of `state_ops.assign` on the wrapped variable.
    """
    restored_tensor = restored_tensors[0]
    if restored_shapes is not None:
      restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
    # Shape validation is only requested when no explicit reshape was applied
    # and the variable's static shape is fully known.
    return state_ops.assign(
        self.op,
        restored_tensor,
        validate_shape=restored_shapes is None and
        self.op.get_shape().is_fully_defined())


class ResourceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles ResourceVariables."""

  def __init__(self, var, slice_spec, name):
    """Wraps a resource variable, or a tensor read from one, for saving.

    Args:
      var: Either a `Tensor`, in which case the first input of its producing
        op is used as the variable handle (presumably a `ReadVariableOp`
        output — confirm at call sites), or a resource variable.
      slice_spec: Slice spec string forwarded to the `SaveSpec`.
      name: Name of the saveable.

    Raises:
      ValueError: If `var` is neither a `Tensor` nor a resource variable.
    """
    self._var_device = var.device
    self._var_shape = var.shape
    if isinstance(var, ops.Tensor):
      self.handle_op = var.op.inputs[0]
      tensor = var
    elif resource_variable_ops.is_resource_variable(var):

      def _read_variable_closure(v):
        def f():
          with ops.device(v.device):
            x = v.read_value()
            # To allow variables placed on non-CPU devices to be checkpointed,
            # we copy them to CPU on the same machine first.
            with ops.device("/device:CPU:0"):
              return array_ops.identity(x)
        return f

      self.handle_op = var.handle
      # A callable is handed to SaveSpec instead of a tensor; presumably
      # SaveSpec invokes it when the value is needed (verify in SaveSpec).
      tensor = _read_variable_closure(var)
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          " Got: %s" % repr(var))
    spec = saveable_object.SaveSpec(tensor, slice_spec, name,
                                    dtype=var.dtype, device=var.device)
    super(ResourceVariableSaveable, self).__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the restored value to the variable through its handle.

    Args:
      restored_tensors: Sequence of restored tensors; only the first element
        is used.
      restored_shapes: `None`, or a sequence whose first element is the shape
        the restored tensor is reshaped to before assignment.

    Returns:
      The result of `shape_safe_assign_variable_handle`.
    """
    restored_tensor = restored_tensors[0]
    if restored_shapes is not None:
      restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
    # Copy the restored tensor to the variable's device.
    with ops.device(self._var_device):
      restored_tensor = array_ops.identity(restored_tensor)
    return resource_variable_ops.shape_safe_assign_variable_handle(
        self.handle_op, self._var_shape, restored_tensor)


def _tensor_comes_from_variable(v):
  """Returns True if `v` is a Tensor produced by an op in `_VARIABLE_OPS`."""
  return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS


def saveable_objects_for_op(op, name):
  """Create `SaveableObject`s from an operation.

  Args:
    op: The object to coerce into SaveableObjects. One of: a SaveableObject
      (yielded as-is); a list/tuple of variable slices or a
      `PartitionedVariable`; a `Trackable` that is not a `Variable`; a resource
      variable; or (graph mode only) anything convertible to a tensor produced
      by a variable op.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string, or if `op` converts to a tensor that
      is not produced by a variable op.
    ValueError: For operations with no known conversion to SaveableObject,
      for inconsistent slice lists, or for non-resource variables when
      executing eagerly.
  """
  if not isinstance(name, six.string_types):
    # `(name,)` so a tuple-valued name is printed rather than unpacked as the
    # format arguments (which would raise an unrelated formatting error).
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        "trackable operations. Name is not a string: %s" % (name,))
  if isinstance(op, saveable_object.SaveableObject):
    yield op
  elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
    if isinstance(op, variables.PartitionedVariable):
      op = list(op)
    # A set of slices: every non-SaveableObject entry must be a sliced
    # Variable, and all slices must share the same full tensor name.
    slice_name = None
    # pylint: disable=protected-access
    for variable in op:
      if isinstance(variable, saveable_object.SaveableObject):
        yield variable
        continue
      if not isinstance(variable, variables.Variable):
        raise ValueError("Slices must all be Variables: %s" % (variable,))
      if not variable._save_slice_info:
        raise ValueError("Slices must all be slices: %s" % variable)
      if slice_name is None:
        slice_name = variable._save_slice_info.full_name
      elif slice_name != variable._save_slice_info.full_name:
        raise ValueError(
            "Slices must all be from the same tensor: %s != %s" %
            (slice_name, variable._save_slice_info.full_name))
      if variable.op.type in _REF_VARIABLE_OPS:
        yield ReferenceVariableSaveable(
            variable, variable._save_slice_info.spec, name)
      else:
        yield ResourceVariableSaveable(
            variable, variable._save_slice_info.spec, name)
    # pylint: enable=protected-access
  elif isinstance(op, trackable.Trackable) and not isinstance(
      op, variables.Variable):
    # pylint: disable=protected-access
    for attr, factory in op._gather_saveables_for_checkpoint().items():
      if attr == trackable.VARIABLE_VALUE_KEY:
        # Keep original name for classes masquerading as variables.
        full_name = name
      else:
        full_name = name + "_" + attr
      produced = factory(full_name) if callable(factory) else factory
      # The produced object is converted under its own `name` attribute.
      for converted in saveable_objects_for_op(produced, produced.name):
        yield converted
    # pylint: enable=protected-access
  else:
    # A variable or tensor.
    if isinstance(op, resource_variable_ops.BaseResourceVariable):
      # pylint: disable=protected-access
      if op._in_graph_mode:
        variable = op._graph_element
      else:
        variable = op
      # pylint: enable=protected-access
      yield ResourceVariableSaveable(variable, "", name)
    else:
      if context.executing_eagerly():
        raise ValueError("Can only save/restore ResourceVariables when "
                         "executing eagerly, got type: %s." % type(op))

      variable = ops.convert_to_tensor(op, as_ref=True)
      if not _tensor_comes_from_variable(variable):
        raise TypeError("names_to_saveables must be a dict mapping string "
                        "names to Tensors/Variables. Not a variable: %s" %
                        variable)
      if variable.op.type in _REF_VARIABLE_OPS:
        yield ReferenceVariableSaveable(variable, "", name)
      else:
        yield ResourceVariableSaveable(variable, "", name)


def op_list_to_dict(op_list, convert_variable_to_tensor=True):
  """Create a dictionary of names to operation lists.

  Args:
    op_list: A (nested) list, tuple, or set of Variables or SaveableObjects.
    convert_variable_to_tensor: Whether or not to convert single Variables
      with no slice info into Tensors. Also applied to the saveables produced
      by any `Trackable` entries in `op_list`.

  Returns:
    A dictionary of names to the operations that must be saved under
    that name.  Variables with save_slice_info are grouped together under the
    same key in no particular order.

  Raises:
    TypeError: If the type of op_list or its elements is not supported.
    ValueError: If at least two saveables share the same name, if slices and
      non-slices share a name, or if a non-resource variable created outside
      graph mode is passed.
  """
  if not isinstance(op_list, (list, tuple, set)):
    raise TypeError("Variables to save should be passed in a dict or a "
                    "list: %s" % op_list)
  # List casting is necessary to support sets.
  op_list = nest.flatten(list(op_list))
  # When ResourceVariables are converted to Tensors, read ops are added to the
  # graph. Sorting the op_list ensures that the resulting graph is always
  # constructed in a deterministic way:
  op_list = sorted(op_list, key=lambda x: x.name)
  names_to_saveables = {}
  # pylint: disable=protected-access
  for var in op_list:
    resource_or_ref_variable = isinstance(
        var, (resource_variable_ops.BaseResourceVariable,
              variables.RefVariable))

    if isinstance(var, saveable_object.SaveableObject):
      # NOTE(review): duplicate SaveableObject names are not checked here; a
      # later entry silently replaces an earlier one with the same name.
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.PartitionedVariable):
      if var.name in names_to_saveables:
        raise ValueError("At least two variables have the same name: %s" %
                         var.name)
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.Variable) and var._save_slice_info:
      # Slices of one logical variable are grouped under its full name.
      name = var._save_slice_info.full_name
      if name in names_to_saveables:
        if not isinstance(names_to_saveables[name], list):
          raise ValueError("Mixing slices and non-slices with the same name: "
                           "%s" % name)
        names_to_saveables[name].append(var)
      else:
        names_to_saveables[name] = [var]
    elif isinstance(var, trackable.Trackable) and not resource_or_ref_variable:
      trackable_saveables = [
          (factory() if callable(factory) else factory)
          for factory in var._gather_saveables_for_checkpoint().values()]
      # Forward the caller's conversion choice to the nested saveables.
      names_to_saveables.update(
          op_list_to_dict(
              trackable_saveables,
              convert_variable_to_tensor=convert_variable_to_tensor))
    else:
      # Variables (reference and resource) have an _in_graph_mode property
      # indicating whether they were created in a graph building context. We
      # also get Tensors when graph building, which do not have this property.
      if not getattr(var, "_in_graph_mode", True):
        if not isinstance(var, resource_variable_ops.BaseResourceVariable):
          raise ValueError(
              "Can only save/restore ResourceVariables when eager execution "
              "is enabled, type: %s." % type(var))
        # Outside graph mode variables are keyed by shared_name; the same
        # object may appear twice, but two distinct objects may not.
        set_var = names_to_saveables.setdefault(var._shared_name, var)
        if set_var is not var:
          raise ValueError(
              ("Two different ResourceVariable objects with the same "
               "shared_name '%s' were passed to the Saver. This likely means "
               "that they were created in different Graphs or isolation "
               "contexts, and may not be checkpointed together.") %
              (var._shared_name,))
      else:
        if convert_variable_to_tensor:
          if isinstance(var, resource_variable_ops.BaseResourceVariable):
            var = var._graph_element  # pylint: disable=protected-access
          else:
            var = ops.convert_to_tensor(var, as_ref=True)
          if not _tensor_comes_from_variable(var):
            raise TypeError("Variable to save is not a Variable: %s" % var)
        # A read of a resource variable is keyed by the handle op's name.
        if var.op.type == "ReadVariableOp":
          name = var.op.inputs[0].op.name
        else:
          name = var.op.name
        if name in names_to_saveables:
          raise ValueError("At least two variables have the same name: %s" %
                           name)
        names_to_saveables[name] = var

  # pylint: enable=protected-access
  return names_to_saveables


def _add_saveable(saveables, seen_ops, saveable):
  """Adds the saveable to the saveables list.

  Args:
    saveables: List to append the SaveableObject to.
    seen_ops: Set of the ops of the saveables already processed.  Used to
      check that each saveable is only saved once.
    saveable: The saveable.

  Raises:
    ValueError: If the saveable has already been processed.
  """
  if saveable.op in seen_ops:
    raise ValueError("The same saveable will be restored with two names: %s" %
                     saveable.name)
  saveables.append(saveable)
  seen_ops.add(saveable.op)


def validate_and_slice_inputs(names_to_saveables):
  """Returns the variables and names that will be used for a Saver.

  Args:
    names_to_saveables: Either a dict (k, v) where k is a string name and v is
      what to save under it (a SaveableObject, variable, tensor read from a
      variable, list of variable slices / PartitionedVariable, or Trackable),
      or a list/tuple/set of such objects, which is first converted to a dict
      with `op_list_to_dict`.

  Returns:
    A list of SaveableObjects.

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a trackable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  if not isinstance(names_to_saveables, dict):
    names_to_saveables = op_list_to_dict(names_to_saveables)

  saveables = []
  # Identity-based set: ops are tracked by object identity, not equality.
  seen_ops = object_identity.ObjectIdentitySet()
  for name, op in sorted(names_to_saveables.items(),
                         # Avoid comparing ops, sort only by name.
                         key=lambda x: x[0]):
    for converted_saveable_object in saveable_objects_for_op(op, name):
      _add_saveable(saveables, seen_ops, converted_saveable_object)
  return saveables