# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for working with and creating SaveableObjects."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.eager import context
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.tracking import base as trackable


# Op names which identify variable reads which should be saved.
_VARIABLE_OPS = set(["Variable",
                     "VariableV2",
                     "AutoReloadVariable",
                     "VarHandleOp",
                     "ReadVariableOp"])

# Op types of reference (non-resource) variables. Variables/tensors whose op
# type is in this set are wrapped in `ReferenceVariableSaveable`; all other
# variable op types are wrapped in `ResourceVariableSaveable`.
_REF_VARIABLE_OPS = frozenset(["Variable",
                               "VariableV2",
                               "AutoReloadVariable"])


def set_cpu0(device_string):
  """Creates a new device string based on `device_string` but using /CPU:0.

  If the device is already on /CPU:0, this is a no-op. Only the device type and
  device index of the parsed spec are overwritten; every other component of
  the parsed spec is kept as-is.

  Args:
    device_string: A device string.

  Returns:
    A device string.
  """
  parsed_device = pydev.DeviceSpec.from_string(device_string)
  parsed_device.device_type = "CPU"
  parsed_device.device_index = 0
  return parsed_device.to_string()


class ReferenceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles reference variables."""

  def __init__(self, var, slice_spec, name):
    """Wraps `var` in a single `SaveSpec`.

    Args:
      var: The reference variable (or ref tensor) to save. Its `dtype` is
        recorded in the spec.
      slice_spec: Slice specification for partitioned variables; callers in
        this module pass "" for whole (unsliced) variables.
      name: The name to save the variable under.
    """
    spec = saveable_object.SaveSpec(var, slice_spec, name, dtype=var.dtype)
    super(ReferenceVariableSaveable, self).__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Returns an op assigning the restored value to the variable.

    Args:
      restored_tensors: Sequence whose first element is the restored value.
      restored_shapes: `None`, or a sequence whose first element is the shape
        to reshape the restored value to before assigning.

    Returns:
      The result of `state_ops.assign`. Shape validation is requested only
      when no `restored_shapes` were given and the variable's static shape is
      fully defined.
    """
    restored_tensor = restored_tensors[0]
    if restored_shapes is not None:
      restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
    return state_ops.assign(
        self.op,
        restored_tensor,
        validate_shape=restored_shapes is None and
        self.op.get_shape().is_fully_defined())


class ResourceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles ResourceVariables."""

  def __init__(self, var, slice_spec, name):
    """Wraps a resource variable (or a read of one) in a single `SaveSpec`.

    Args:
      var: Either a `Tensor` read from a resource variable, or a
        `ResourceVariable`.
      slice_spec: Slice specification for partitioned variables; callers in
        this module pass "" for whole (unsliced) variables.
      name: The name to save the variable under.

    Raises:
      ValueError: If `var` is neither a `Tensor` nor a `ResourceVariable`.
    """
    # Validate the type first so unsupported inputs get the descriptive
    # ValueError below instead of an AttributeError on `.device`/`.shape`.
    if isinstance(var, ops.Tensor):
      # NOTE(review): assumes `var` is produced by a read op whose first input
      # is the variable's resource handle (e.g. ReadVariableOp) — confirm.
      self.handle_op = var.op.inputs[0]
      tensor = var
    elif isinstance(var, resource_variable_ops.ResourceVariable):

      def _read_variable_closure(v):
        def f():
          with ops.device(v.device):
            x = v.read_value()
            # To allow variables placed on non-CPU devices to be checkpointed,
            # we copy them to CPU on the same machine first.
            with ops.device("/device:CPU:0"):
              return array_ops.identity(x)
        return f

      self.handle_op = var.handle
      # A callable is handed to SaveSpec instead of a tensor; presumably
      # SaveSpec invokes it when the value is needed — confirm in
      # saveable_object.SaveSpec.
      tensor = _read_variable_closure(var)
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          " Got: %s" % repr(var))
    self._var_device = var.device
    self._var_shape = var.shape
    spec = saveable_object.SaveSpec(tensor, slice_spec, name,
                                    dtype=var.dtype)
    super(ResourceVariableSaveable, self).__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Returns an op assigning the restored value to the variable's handle.

    Args:
      restored_tensors: Sequence whose first element is the restored value.
      restored_shapes: `None`, or a sequence whose first element is the shape
        to reshape the restored value to before assigning.

    Returns:
      The result of `resource_variable_ops.shape_safe_assign_variable_handle`
      applied to the stored handle, the variable's shape, and the restored
      value (copied onto the variable's device).
    """
    restored_tensor = restored_tensors[0]
    if restored_shapes is not None:
      restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
    # Copy the restored tensor to the variable's device.
    with ops.device(self._var_device):
      restored_tensor = array_ops.identity(restored_tensor)
      return resource_variable_ops.shape_safe_assign_variable_handle(
          self.handle_op, self._var_shape, restored_tensor)


def _tensor_comes_from_variable(v):
  """Returns True if `v` is a Tensor whose op type is in `_VARIABLE_OPS`."""
  return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS


def saveable_objects_for_op(op, name):
  """Create `SaveableObject`s from an operation.

  Args:
    op: A variable, operation, or SaveableObject to coerce into a
      SaveableObject.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string, or if `op` converts to a tensor that
      does not come from a variable.
    ValueError: For operations with no known conversion to SaveableObject,
      for inconsistent slice lists, or for non-resource variables when
      executing eagerly.
  """
  if not isinstance(name, six.string_types):
    # `(name,)` keeps formatting correct even when `name` is itself a tuple.
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        "trackable operations. Name is not a string: %s" % (name,))
  if isinstance(op, saveable_object.SaveableObject):
    yield op
  elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
    if isinstance(op, variables.PartitionedVariable):
      op = list(op)
    # A set of slices.
    slice_name = None
    # pylint: disable=protected-access
    for variable in op:
      if not isinstance(variable, variables.Variable):
        # `(variable,)`: a non-Variable element may itself be a tuple.
        raise ValueError("Slices must all be Variables: %s" % (variable,))
      if not variable._save_slice_info:
        raise ValueError("Slices must all be slices: %s" % variable)
      if slice_name is None:
        slice_name = variable._save_slice_info.full_name
      elif slice_name != variable._save_slice_info.full_name:
        raise ValueError(
            "Slices must all be from the same tensor: %s != %s" %
            (slice_name, variable._save_slice_info.full_name))
      if variable.op.type in _REF_VARIABLE_OPS:
        yield ReferenceVariableSaveable(
            variable, variable._save_slice_info.spec, name)
      else:
        yield ResourceVariableSaveable(
            variable, variable._save_slice_info.spec, name)
    # pylint: enable=protected-access
  elif isinstance(op, trackable.Trackable) and not isinstance(
      op, variables.Variable):
    # pylint: disable=protected-access
    for attr, factory in op._gather_saveables_for_checkpoint().items():
      if attr == trackable.VARIABLE_VALUE_KEY:
        # Keep original name for classes masquerading as variables.
        full_name = name
      else:
        full_name = name + "_" + attr
      saveable = factory(full_name) if callable(factory) else factory
      # Recurse under the name the resulting object carries (presumably
      # `full_name` for callable factories).
      for converted in saveable_objects_for_op(saveable, saveable.name):
        yield converted
    # pylint: enable=protected-access
  else:
    # A variable or tensor.
    if isinstance(op, resource_variable_ops.ResourceVariable):
      # pylint: disable=protected-access
      if op._in_graph_mode:
        variable = op._graph_element
      else:
        variable = op
      # pylint: enable=protected-access
      yield ResourceVariableSaveable(variable, "", name)
    else:
      with ops.init_scope():
        if context.executing_eagerly():
          raise ValueError("Can only save/restore ResourceVariables when "
                           "executing eagerly, got type: %s." % type(op))

      variable = ops.internal_convert_to_tensor(op, as_ref=True)
      if not _tensor_comes_from_variable(variable):
        raise TypeError("names_to_saveables must be a dict mapping string "
                        "names to Tensors/Variables. Not a variable: %s" %
                        variable)
      if variable.op.type in _REF_VARIABLE_OPS:
        yield ReferenceVariableSaveable(variable, "", name)
      else:
        yield ResourceVariableSaveable(variable, "", name)


def op_list_to_dict(op_list, convert_variable_to_tensor=True):
  """Create a dictionary of names to operation lists.

  Args:
    op_list: A list, tuple, or set of Variables or SaveableObjects.
    convert_variable_to_tensor: Whether or not to convert single Variables
      with no slice info into Tensors.

  Returns:
    A dictionary of names to the operations that must be saved under
    that name.  Variables with save_slice_info are grouped together under the
    same key in no particular order.

  Raises:
    TypeError: If the type of op_list or its elements is not supported.
    ValueError: If at least two saveables share the same name, if sliced and
      unsliced variables share a name, if a non-resource variable is passed
      while executing eagerly, or if two distinct eager ResourceVariables
      share a `shared_name`.
  """
  if not isinstance(op_list, (list, tuple, set)):
    raise TypeError("Variables to save should be passed in a dict or a "
                    "list: %s" % op_list)
  # When ResourceVariables are converted to Tensors, read ops are added to the
  # graph. Sorting the op_list ensures that the resulting graph is always
  # constructed in a deterministic way:
  op_list = sorted(op_list, key=lambda x: x.name)
  names_to_saveables = {}
  # pylint: disable=protected-access
  for var in op_list:
    if isinstance(var, saveable_object.SaveableObject):
      # NOTE(review): unlike the other branches, a duplicate name here
      # silently overwrites instead of raising ValueError.
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.PartitionedVariable):
      if var.name in names_to_saveables:
        raise ValueError("At least two variables have the same name: %s" %
                         var.name)
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.Variable) and var._save_slice_info:
      name = var._save_slice_info.full_name
      if name in names_to_saveables:
        if not isinstance(names_to_saveables[name], list):
          raise ValueError("Mixing slices and non-slices with the same name: "
                           "%s" % name)
        names_to_saveables[name].append(var)
      else:
        names_to_saveables[name] = [var]
    elif (isinstance(var, trackable.Trackable)
          and not isinstance(var, variables.Variable)):
      trackable_saveables = [
          (factory() if callable(factory) else factory)
          for factory in var._gather_saveables_for_checkpoint().values()]
      # NOTE(review): the recursive call does not forward
      # `convert_variable_to_tensor`, and `update` overwrites existing keys
      # without a duplicate check.
      names_to_saveables.update(
          op_list_to_dict(trackable_saveables))
    else:
      # Variables (reference and resource) have an _in_graph_mode property
      # indicating whether they were created in a graph building context. We
      # also get Tensors when graph building, which do not have this property.
      if not getattr(var, "_in_graph_mode", True):
        if not isinstance(var, resource_variable_ops.ResourceVariable):
          raise ValueError(
              "Can only save/restore ResourceVariables when eager execution "
              "is enabled, type: %s." % type(var))
        set_var = names_to_saveables.setdefault(var._shared_name, var)
        if set_var is not var:
          raise ValueError(
              ("Two different ResourceVariable objects with the same "
               "shared_name '%s' were passed to the Saver. This likely means "
               "that they were created in different Graphs or isolation "
               "contexts, and may not be checkpointed together.") %
              (var._shared_name,))
      else:
        if convert_variable_to_tensor:
          if isinstance(var, resource_variable_ops.ResourceVariable):
            var = var._graph_element  # pylint: disable=protected-access
          else:
            var = ops.internal_convert_to_tensor(var, as_ref=True)
          if not _tensor_comes_from_variable(var):
            raise TypeError("Variable to save is not a Variable: %s" % var)
        if var.op.type == "ReadVariableOp":
          # Key reads by the op producing their first input rather than by
          # the read op itself.
          name = var.op.inputs[0].op.name
        else:
          name = var.op.name
        if name in names_to_saveables:
          raise ValueError("At least two variables have the same name: %s" %
                           name)
        names_to_saveables[name] = var

  # pylint: enable=protected-access
  return names_to_saveables


def _add_saveable(saveables, seen_ops, saveable):
  """Adds the saveable to the saveables list.

  Args:
    saveables: List to append the SaveableObject to.
    seen_ops: Set of the ops of the saveables already processed.  Used to
      check that each saveable is only saved once.
    saveable: The saveable.

  Raises:
    ValueError: If the saveable has already been processed.
  """
  if saveable.op in seen_ops:
    raise ValueError("The same saveable will be restored with two names: %s" %
                     saveable.name)
  saveables.append(saveable)
  seen_ops.add(saveable.op)


def validate_and_slice_inputs(names_to_saveables):
  """Returns the variables and names that will be used for a Saver.

  Args:
    names_to_saveables: A dict (k, v) where k is the name of an operation and
      v is an operation to save or a SaveableObject. Any non-dict input is
      first converted with `op_list_to_dict`.

  Returns:
    A list of SaveableObjects, ordered by name.

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a trackable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  if not isinstance(names_to_saveables, dict):
    names_to_saveables = op_list_to_dict(names_to_saveables)

  saveables = []
  seen_ops = set()
  for name, op in sorted(names_to_saveables.items(),
                         # Avoid comparing ops, sort only by name.
                         key=lambda x: x[0]):
    for converted_saveable_object in saveable_objects_for_op(op, name):
      _add_saveable(saveables, seen_ops, converted_saveable_object)
  return saveables