1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Variable functions. 17""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import functools 23import re 24 25from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope 26from tensorflow.contrib.framework.python.ops import gen_variable_ops 27from tensorflow.contrib.util import loader 28from tensorflow.core.protobuf import saver_pb2 29from tensorflow.python import pywrap_tensorflow 30from tensorflow.python.framework import device as tf_device 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import ops 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import variable_scope 36from tensorflow.python.platform import resource_loader 37from tensorflow.python.platform import tf_logging as logging 38from tensorflow.python.training import saver as tf_saver 39from tensorflow.python.training import training_util 40from tensorflow.python.util.deprecation import deprecated 41 42 43__all__ = ['add_model_variable', 44 'assert_global_step', 45 'assert_or_get_global_step', 46 'assign_from_checkpoint', 47 'assign_from_checkpoint_fn', 48 'assign_from_values', 49 'assign_from_values_fn', 50 'create_global_step', 51 'filter_variables', 52 'get_global_step', 53 'get_or_create_global_step', 54 'get_local_variables', 55 'get_model_variables', 56 'get_trainable_variables', 57 'get_unique_variable', 58 'get_variables_by_name', 59 'get_variables_by_suffix', 60 'get_variable_full_name', 61 'get_variables_to_restore', 62 'get_variables', 63 'global_variable', 64 'local_variable', 65 'model_variable', 66 'variable', 67 'VariableDeviceChooser', 68 'zero_initializer'] 69 70 71def zero_initializer(ref, use_locking=True, name="zero_initializer"): 72 """Initialize 'ref' with all zeros, ref tensor should be uninitialized. 73 If already initialized, you will get ValueError. This op is intended to 74 save memory during initialization. 75 Args: 76 ref: ref of the tensor need to be zero initialized. 77 name: optional name for this operation. 78 Returns: 79 ref that initialized. 80 Raises: 81 ValueError: If ref tensor is initialized. 82 """ 83 loader.load_op_library( 84 resource_loader.get_path_to_datafile("_variable_ops.so")) 85 return gen_variable_ops.zero_initializer(ref, name=name) 86 87@deprecated(None, "Please switch to tf.train.assert_global_step") 88def assert_global_step(global_step_tensor): 89 training_util.assert_global_step(global_step_tensor) 90 91 92def assert_or_get_global_step(graph=None, global_step_tensor=None): 93 """Verifies that a global step tensor is valid or gets one if None is given. 94 95 If `global_step_tensor` is not None, check that it is a valid global step 96 tensor (using `assert_global_step`). Otherwise find a global step tensor using 97 `get_global_step` and return it. 98 99 Args: 100 graph: The graph to find the global step tensor for. 101 global_step_tensor: The tensor to check for suitability as a global step. 102 If None is given (the default), find a global step tensor. 103 104 Returns: 105 A tensor suitable as a global step, or `None` if none was provided and none 106 was found. 107 """ 108 if global_step_tensor is None: 109 # Get the global step tensor the same way the supervisor would. 110 global_step_tensor = get_global_step(graph) 111 else: 112 assert_global_step(global_step_tensor) 113 return global_step_tensor 114 115@deprecated(None, "Please switch to tf.train.get_global_step") 116def get_global_step(graph=None): 117 return training_util.get_global_step(graph) 118 119@deprecated(None, "Please switch to tf.train.create_global_step") 120def create_global_step(graph=None): 121 """Create global step tensor in graph. 122 123 This API is deprecated. Use core framework training version instead. 124 125 Args: 126 graph: The graph in which to create the global step tensor. If missing, 127 use default graph. 128 129 Returns: 130 Global step tensor. 131 132 Raises: 133 ValueError: if global step tensor is already defined. 134 """ 135 return training_util.create_global_step(graph) 136 137@deprecated(None, "Please switch to tf.train.get_or_create_global_step") 138def get_or_create_global_step(graph=None): 139 """Returns and create (if necessary) the global step tensor. 140 141 Args: 142 graph: The graph in which to create the global step tensor. If missing, use 143 default graph. 144 145 Returns: 146 The global step tensor. 147 """ 148 return training_util.get_or_create_global_step(graph) 149 150 151def local_variable(initial_value, 152 validate_shape=True, 153 name=None, 154 use_resource=None): 155 """Create a variable with a value and add it to `GraphKeys.LOCAL_VARIABLES`. 156 157 Args: 158 initial_value: See variables.Variable.__init__. 159 validate_shape: See variables.Variable.__init__. 160 name: See variables.Variable.__init__. 161 use_resource: If `True` use a ResourceVariable instead of a Variable. 162 Returns: 163 New variable. 164 """ 165 return variable_scope.variable( 166 initial_value, trainable=False, 167 collections=[ops.GraphKeys.LOCAL_VARIABLES], 168 validate_shape=validate_shape, 169 use_resource=use_resource, 170 name=name) 171 172 173def global_variable(initial_value, 174 validate_shape=True, 175 name=None, 176 use_resource=None): 177 """Create a variable with a value and add it to `GraphKeys.GLOBAL_VARIABLES`. 178 179 Args: 180 initial_value: See variables.Variable.__init__. 181 validate_shape: See variables.Variable.__init__. 182 name: See variables.Variable.__init__. 183 use_resource: If `True` use a ResourceVariable instead of a Variable. 184 Returns: 185 New variable. 186 """ 187 return variable_scope.variable( 188 initial_value, trainable=False, 189 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 190 validate_shape=validate_shape, 191 use_resource=use_resource, 192 name=name) 193 194 195@contrib_add_arg_scope 196def variable(name, shape=None, dtype=None, initializer=None, 197 regularizer=None, trainable=True, collections=None, 198 caching_device=None, device=None, 199 partitioner=None, custom_getter=None, use_resource=None): 200 """Gets an existing variable with these parameters or creates a new one. 201 202 Args: 203 name: the name of the new or existing variable. 204 shape: shape of the new or existing variable. 205 dtype: type of the new or existing variable (defaults to `DT_FLOAT`). 206 initializer: initializer for the variable if one is created. 207 regularizer: a (Tensor -> Tensor or None) function; the result of 208 applying it on a newly created variable will be added to the collection 209 GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. 210 trainable: If `True` also add the variable to the graph collection 211 `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). 212 collections: A list of collection names to which the Variable will be added. 213 If None it would default to `tf.GraphKeys.GLOBAL_VARIABLES`. 214 caching_device: Optional device string or function describing where the 215 Variable should be cached for reading. Defaults to the Variable's 216 device. 217 device: Optional device to place the variable. It can be an string or a 218 function that is called to get the device for the variable. 219 partitioner: Optional callable that accepts a fully defined `TensorShape` 220 and dtype of the `Variable` to be created, and returns a list of 221 partitions for each axis (currently only one axis can be partitioned). 222 custom_getter: Callable that allows overwriting the internal 223 get_variable method and has to have the same signature. 224 use_resource: If `True` use a ResourceVariable instead of a Variable. 225 226 Returns: 227 The created or existing variable. 228 """ 229 collections = list(collections if collections is not None 230 else [ops.GraphKeys.GLOBAL_VARIABLES]) 231 232 # Remove duplicates 233 collections = list(set(collections)) 234 getter = variable_scope.get_variable 235 if custom_getter is not None: 236 getter = functools.partial(custom_getter, 237 reuse=variable_scope.get_variable_scope().reuse) 238 with ops.device(device or ''): 239 return getter(name, shape=shape, dtype=dtype, 240 initializer=initializer, 241 regularizer=regularizer, 242 trainable=trainable, 243 collections=collections, 244 caching_device=caching_device, 245 partitioner=partitioner, 246 use_resource=use_resource) 247 248 249@contrib_add_arg_scope 250def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None, 251 regularizer=None, trainable=True, collections=None, 252 caching_device=None, device=None, partitioner=None, 253 custom_getter=None, use_resource=None): 254 """Gets an existing model variable with these parameters or creates a new one. 255 256 Args: 257 name: the name of the new or existing variable. 258 shape: shape of the new or existing variable. 259 dtype: type of the new or existing variable (defaults to `DT_FLOAT`). 260 initializer: initializer for the variable if one is created. 261 regularizer: a (Tensor -> Tensor or None) function; the result of 262 applying it on a newly created variable will be added to the collection 263 GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. 264 trainable: If `True` also add the variable to the graph collection 265 `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). 266 collections: A list of collection names to which the Variable will be added. 267 Note that the variable is always also added to the 268 `GraphKeys.GLOBAL_VARIABLES` and `GraphKeys.MODEL_VARIABLES` collections. 269 caching_device: Optional device string or function describing where the 270 Variable should be cached for reading. Defaults to the Variable's 271 device. 272 device: Optional device to place the variable. It can be an string or a 273 function that is called to get the device for the variable. 274 partitioner: Optional callable that accepts a fully defined `TensorShape` 275 and dtype of the `Variable` to be created, and returns a list of 276 partitions for each axis (currently only one axis can be partitioned). 277 custom_getter: Callable that allows overwriting the internal 278 get_variable method and has to have the same signature. 279 use_resource: If `True` use a ResourceVariable instead of a Variable. 280 281 Returns: 282 The created or existing variable. 283 """ 284 collections = list(collections or []) 285 collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES] 286 var = variable(name, shape=shape, dtype=dtype, 287 initializer=initializer, regularizer=regularizer, 288 trainable=trainable, collections=collections, 289 caching_device=caching_device, device=device, 290 partitioner=partitioner, custom_getter=custom_getter, 291 use_resource=use_resource) 292 return var 293 294 295def add_model_variable(var): 296 """Adds a variable to the `GraphKeys.MODEL_VARIABLES` collection. 297 298 Args: 299 var: a variable. 300 """ 301 if var not in ops.get_collection(ops.GraphKeys.MODEL_VARIABLES): 302 ops.add_to_collection(ops.GraphKeys.MODEL_VARIABLES, var) 303 304 305def get_variables(scope=None, suffix=None, 306 collection=ops.GraphKeys.GLOBAL_VARIABLES): 307 """Gets the list of variables, filtered by scope and/or suffix. 308 309 Args: 310 scope: an optional scope for filtering the variables to return. Can be a 311 variable scope or a string. 312 suffix: an optional suffix for filtering the variables to return. 313 collection: in which collection search for. Defaults to 314 `GraphKeys.GLOBAL_VARIABLES`. 315 316 Returns: 317 a list of variables in collection with scope and suffix. 318 """ 319 if isinstance(scope, variable_scope.VariableScope): 320 scope = scope.name 321 if suffix is not None: 322 if ':' not in suffix: 323 suffix += ':' 324 scope = (scope or '') + '.*' + suffix 325 return ops.get_collection(collection, scope) 326 327 328def get_model_variables(scope=None, suffix=None): 329 """Gets the list of model variables, filtered by scope and/or suffix. 330 331 Args: 332 scope: an optional scope for filtering the variables to return. 333 suffix: an optional suffix for filtering the variables to return. 334 335 Returns: 336 a list of variables in collection with scope and suffix. 337 """ 338 return get_variables(scope, suffix, ops.GraphKeys.MODEL_VARIABLES) 339 340 341def get_local_variables(scope=None, suffix=None): 342 """Gets the list of local variables, filtered by scope and/or suffix. 343 344 Args: 345 scope: an optional scope for filtering the variables to return. 346 suffix: an optional suffix for filtering the variables to return. 347 348 Returns: 349 a list of variables in collection with scope and suffix. 350 """ 351 return get_variables(scope, suffix, ops.GraphKeys.LOCAL_VARIABLES) 352 353 354def get_trainable_variables(scope=None, suffix=None): 355 """Gets the list of trainable variables, filtered by scope and/or suffix. 356 357 Args: 358 scope: an optional scope for filtering the variables to return. 359 suffix: an optional suffix for filtering the variables to return. 360 361 Returns: 362 a list of variables in the trainable collection with scope and suffix. 363 """ 364 return get_variables(scope, suffix, ops.GraphKeys.TRAINABLE_VARIABLES) 365 366 367def get_variables_to_restore(include=None, exclude=None): 368 """Gets the list of the variables to restore. 369 370 Args: 371 include: an optional list/tuple of scope strings for filtering which 372 variables from the VARIABLES collection to include. None would include all 373 the variables. 374 exclude: an optional list/tuple of scope strings for filtering which 375 variables from the VARIABLES collection to exclude. None it would not 376 exclude any. 377 378 Returns: 379 a list of variables to restore. 380 381 Raises: 382 TypeError: include or exclude is provided but is not a list or a tuple. 383 """ 384 if include is None: 385 # Include all variables. 386 vars_to_include = get_variables() 387 else: 388 if not isinstance(include, (list, tuple)): 389 raise TypeError('include is provided but is not a list or a tuple.') 390 vars_to_include = [] 391 for scope in include: 392 vars_to_include += get_variables(scope) 393 vars_to_exclude = set() 394 if exclude is not None: 395 if not isinstance(exclude, (list, tuple)): 396 raise TypeError('exclude is provided but is not a list or a tuple.') 397 for scope in exclude: 398 vars_to_exclude |= set(get_variables(scope)) 399 # Exclude the variables in vars_to_exclude 400 return [v for v in vars_to_include if v not in vars_to_exclude] 401 402 403def get_variables_by_suffix(suffix, scope=None): 404 """Gets the list of variables that end with the given suffix. 405 406 Args: 407 suffix: suffix for filtering the variables to return. 408 scope: an optional scope for filtering the variables to return. 409 410 Returns: 411 a copied list of variables with the given name and prefix. 412 """ 413 return get_variables(scope=scope, suffix=suffix) 414 415 416def get_variables_by_name(given_name, scope=None): 417 """Gets the list of variables that were given that name. 418 419 Args: 420 given_name: name given to the variable without any scope. 421 scope: an optional scope for filtering the variables to return. 422 423 Returns: 424 a copied list of variables with the given name and scope. 425 """ 426 suffix = '/' + given_name + ':|^' + given_name + ':' 427 return get_variables(scope=scope, suffix=suffix) 428 429 430def get_unique_variable(var_op_name): 431 """Gets the variable uniquely identified by that var_op_name. 432 433 Args: 434 var_op_name: the full name of the variable op, including the scope. 435 436 Returns: 437 a tensorflow variable. 438 439 Raises: 440 ValueError: if no variable uniquely identified by the name exists. 441 """ 442 candidates = get_variables(scope=var_op_name) 443 if not candidates: 444 raise ValueError('Couldn\'t find variable %s' % var_op_name) 445 446 for candidate in candidates: 447 if candidate.op.name == var_op_name: 448 return candidate 449 raise ValueError('Variable %s does not uniquely identify a variable' % 450 var_op_name) 451 452 453def assign_from_values(var_names_to_values): 454 """Creates an assignment operation from a given mapping. 455 456 This function provides a mechanism for performing assignment of variables 457 to values in a way that does not fill the graph with large assignment values. 458 459 Args: 460 var_names_to_values: A map from variable names to values. 461 462 Returns: 463 assign_op: An `Operation` that assigns each of the given variables to the 464 requested values. 465 feed_dict: The feed dictionary to use when evaluating `assign_op`. 466 467 Raises: 468 ValueError: if any of the given variable names were not found. 469 """ 470 feed_dict = {} 471 assign_ops = [] 472 473 for var_name in var_names_to_values: 474 var_value = var_names_to_values[var_name] 475 var = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, var_name) 476 if not var: 477 raise ValueError('Variable %s wasn\'t found' % var_name) 478 elif len(var) > 1: 479 # tf.get_collection is just a filter on the prefix: find the exact match: 480 found = False 481 for v in var: 482 if v.op.name == var_name: 483 var = v 484 found = True 485 break 486 487 if not found: 488 raise ValueError('Variable %s doesn\'t uniquely identify a variable' % 489 var_name) 490 else: 491 var = var[0] 492 493 # TODO(nsilberman): ensure placeholder and assign are on the same device. 494 # Assign a placeholder to the value that will be filled later. 495 placeholder_name = 'placeholder/' + var.op.name 496 placeholder_value = array_ops.placeholder( 497 dtype=var.dtype.base_dtype, 498 shape=var.get_shape(), 499 name=placeholder_name) 500 assign_ops.append(var.assign(placeholder_value)) 501 502 feed_dict[placeholder_value] = var_value.reshape(var.get_shape()) 503 504 assign_op = control_flow_ops.group(*assign_ops) 505 return assign_op, feed_dict 506 507 508def assign_from_values_fn(var_names_to_values): 509 """Returns a function that assigns specific variables from the given values. 510 511 This function provides a mechanism for performing assignment of variables 512 to values in a way that does not fill the graph with large assignment values. 513 514 Args: 515 var_names_to_values: A map from variable names to values. 516 517 Returns: 518 A function that takes a single argument, a `tf.Session`, that applies the 519 assignment operation. 520 521 Raises: 522 ValueError: if any of the given variable names were not found. 523 """ 524 assign_op, feed_dict = assign_from_values(var_names_to_values) 525 def callback(session): 526 return session.run(assign_op, feed_dict) 527 return callback 528 529 530# pylint: disable=protected-access 531# Currently variable_scope doesn't provide very good APIs to access 532# all variables under scope and retrieve and check existing scopes. 533def get_variable_full_name(var): 534 """Returns the full name of a variable. 535 536 For normal Variables, this is the same as the var.op.name. For 537 sliced or PartitionedVariables, this name is the same for all the 538 slices/partitions. In both cases, this is normally the name used in 539 a checkpoint file. 540 541 Args: 542 var: A `Variable` object. 543 544 Returns: 545 A string that is the full name. 546 """ 547 if var._save_slice_info: 548 return var._save_slice_info.full_name 549 else: 550 return var.op.name 551 552 553# TODO(nsilberman): add flag to load exponential moving averages instead 554# 555# TODO(sguada): Update docs in slim/g3doc/index.md to describe 556# the new feature where the var_list dictionary can have values that 557# are each a list of Variables. 558def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False): 559 """Creates an operation to assign specific variables from a checkpoint. 560 561 Args: 562 model_path: The full path to the model checkpoint. To get latest checkpoint 563 use `model_path = tf.train.latest_checkpoint(checkpoint_dir)` 564 var_list: A list of (possibly partitioned) `Variable` objects 565 or a dictionary mapping names in the checkpoint to the 566 corresponding variables or list of variables to initialize 567 from that checkpoint value. For partitioned Variables, the 568 name in the checkpoint must be the full variable, not the 569 name of the partitioned variable, eg. "my_var" rather than 570 "my_var/part_4". If empty, returns no_op(), {}. 571 ignore_missing_vars: Boolean, if True ignore variables missing in the 572 checkpoint with a warning instead of failing. 573 574 Returns: 575 the restore_op and the feed_dict that need to be run to restore var_list. 576 577 Raises: 578 ValueError: If `ignore_missing_vars` is False and the checkpoint specified 579 at `model_path` is missing one of the variables in `var_list`. 580 """ 581 # Normalize var_list into a dictionary mapping names in the 582 # checkpoint to the list of variables to initialize from that 583 # checkpoint variable. Sliced (including partitioned) variables will 584 # end up under the same key. 585 grouped_vars = {} 586 if isinstance(var_list, (tuple, list)): 587 for var in var_list: 588 ckpt_name = get_variable_full_name(var) 589 if ckpt_name not in grouped_vars: 590 grouped_vars[ckpt_name] = [] 591 grouped_vars[ckpt_name].append(var) 592 593 else: 594 for ckpt_name, value in var_list.items(): 595 if isinstance(value, (tuple, list)): 596 grouped_vars[ckpt_name] = value 597 else: 598 grouped_vars[ckpt_name] = [value] 599 600 # Read each checkpoint entry. Create a placeholder variable and 601 # add the (possibly sliced) data from the checkpoint to the feed_dict. 602 reader = pywrap_tensorflow.NewCheckpointReader(model_path) 603 feed_dict = {} 604 assign_ops = [] 605 for ckpt_name in grouped_vars: 606 if not reader.has_tensor(ckpt_name): 607 log_str = 'Checkpoint is missing variable [%s]' % ckpt_name 608 if ignore_missing_vars: 609 logging.warning(log_str) 610 continue 611 else: 612 raise ValueError(log_str) 613 ckpt_value = reader.get_tensor(ckpt_name) 614 615 for var in grouped_vars[ckpt_name]: 616 placeholder_tensor = array_ops.placeholder( 617 dtype=var.dtype.base_dtype, 618 shape=var.get_shape(), 619 name='placeholder/' + var.op.name) 620 assign_ops.append(var.assign(placeholder_tensor)) 621 622 if not var._save_slice_info: 623 if var.get_shape() != ckpt_value.shape: 624 raise ValueError( 625 'Total size of new array must be unchanged for %s ' 626 'lh_shape: [%s], rh_shape: [%s]' 627 % (ckpt_name, str(ckpt_value.shape), str(var.get_shape()))) 628 629 feed_dict[placeholder_tensor] = ckpt_value.reshape(ckpt_value.shape) 630 else: 631 slice_dims = zip(var._save_slice_info.var_offset, 632 var._save_slice_info.var_shape) 633 slice_dims = [(start, start + size) for (start, size) in slice_dims] 634 slice_dims = [slice(*x) for x in slice_dims] 635 slice_value = ckpt_value[slice_dims] 636 slice_value = slice_value.reshape(var._save_slice_info.var_shape) 637 feed_dict[placeholder_tensor] = slice_value 638 639 assign_op = control_flow_ops.group(*assign_ops) 640 return assign_op, feed_dict 641# pylint: enable=protected-access 642 643 644def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False, 645 reshape_variables=False): 646 """Returns a function that assigns specific variables from a checkpoint. 647 648 If ignore_missing_vars is True and no variables are found in the checkpoint 649 it returns None. 650 651 Args: 652 model_path: The full path to the model checkpoint. To get latest checkpoint 653 use `model_path = tf.train.latest_checkpoint(checkpoint_dir)` 654 var_list: A list of `Variable` objects or a dictionary mapping names in the 655 checkpoint to the corresponding variables to initialize. If empty or 656 `None`, it would return `no_op(), None`. 657 ignore_missing_vars: Boolean, if True it would ignore variables missing in 658 the checkpoint with a warning instead of failing. 659 reshape_variables: Boolean, if True it would automatically reshape variables 660 which are of different shape then the ones stored in the checkpoint but 661 which have the same number of elements. 662 663 Returns: 664 A function that takes a single argument, a `tf.Session`, that applies the 665 assignment operation. If no matching variables were found in the checkpoint 666 then `None` is returned. 667 668 Raises: 669 ValueError: If var_list is empty. 670 """ 671 if not var_list: 672 raise ValueError('var_list cannot be empty') 673 if ignore_missing_vars: 674 reader = pywrap_tensorflow.NewCheckpointReader(model_path) 675 if isinstance(var_list, dict): 676 var_dict = var_list 677 else: 678 var_dict = {var.op.name: var for var in var_list} 679 available_vars = {} 680 for var in var_dict: 681 if reader.has_tensor(var): 682 available_vars[var] = var_dict[var] 683 else: 684 logging.warning( 685 'Variable %s missing in checkpoint %s', var, model_path) 686 var_list = available_vars 687 if var_list: 688 saver = tf_saver.Saver(var_list, reshape=reshape_variables, 689 write_version=saver_pb2.SaverDef.V1) 690 def callback(session): 691 saver.restore(session, model_path) 692 return callback 693 else: 694 logging.warning('No Variables to restore') 695 return None 696 697 698class VariableDeviceChooser(object): 699 """Device chooser for variables. 700 701 When using a parameter server it will assign them in a round-robin fashion. 702 When not using a parameter server it allows GPU or CPU placement. 703 """ 704 705 def __init__(self, 706 num_tasks=0, 707 job_name='ps', 708 device_type='CPU', 709 device_index=0): 710 """Initialize VariableDeviceChooser. 711 712 Usage: 713 To use with 2 parameter servers: 714 VariableDeviceChooser(2) 715 716 To use without parameter servers: 717 VariableDeviceChooser() 718 VariableDeviceChooser(device_type='GPU') # For GPU placement 719 720 Args: 721 num_tasks: number of tasks. 722 job_name: String, a name for the parameter server job. 723 device_type: Optional device type string (e.g. "CPU" or "GPU") 724 device_index: int. Optional device index. If left 725 unspecified, device represents 'any' device_index. 726 """ 727 self._job_name = job_name 728 self._device_type = device_type 729 self._device_index = device_index 730 self._num_tasks = num_tasks 731 self._next_task_id = 0 732 733 def __call__(self, op): 734 device_spec = tf_device.DeviceSpec(device_type=self._device_type, 735 device_index=self._device_index) 736 if self._num_tasks > 0: 737 task_id = self._next_task_id 738 self._next_task_id = (self._next_task_id + 1) % self._num_tasks 739 device_spec.job = self._job_name 740 device_spec.task = task_id 741 return device_spec.to_string() 742 743 744def filter_variables(var_list, include_patterns=None, exclude_patterns=None, 745 reg_search=True): 746 """Filter a list of variables using regular expressions. 747 748 First includes variables according to the list of include_patterns. 749 Afterwards, eliminates variables according to the list of exclude_patterns. 750 751 For example, one can obtain a list of variables with the weights of all 752 convolutional layers (depending on the network definition) by: 753 754 ```python 755 variables = tf.contrib.framework.get_model_variables() 756 conv_weight_variables = tf.contrib.framework.filter_variables( 757 variables, 758 include_patterns=['Conv'], 759 exclude_patterns=['biases', 'Logits']) 760 ``` 761 762 Args: 763 var_list: list of variables. 764 include_patterns: list of regular expressions to include. Defaults to None, 765 which means all variables are selected according to the include rules. 766 A variable is included if it matches any of the include_patterns. 767 exclude_patterns: list of regular expressions to exclude. Defaults to None, 768 which means all variables are selected according to the exclude rules. 769 A variable is excluded if it matches any of the exclude_patterns. 770 reg_search: boolean. If True (default), performs re.search to find matches 771 (i.e. pattern can match any substring of the variable name). If False, 772 performs re.match (i.e. regexp should match from the beginning of the 773 variable name). 774 775 Returns: 776 filtered list of variables. 777 """ 778 if reg_search: 779 reg_exp_func = re.search 780 else: 781 reg_exp_func = re.match 782 783 # First include variables. 784 if include_patterns is None: 785 included_variables = list(var_list) 786 else: 787 included_variables = [] 788 for var in var_list: 789 if any(reg_exp_func(ptrn, var.name) for ptrn in include_patterns): 790 included_variables.append(var) 791 792 # Afterwards, exclude variables. 793 if exclude_patterns is None: 794 filtered_variables = included_variables 795 else: 796 filtered_variables = [] 797 for var in included_variables: 798 if not any(reg_exp_func(ptrn, var.name) for ptrn in exclude_patterns): 799 filtered_variables.append(var) 800 801 return filtered_variables 802