# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Variable functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import re

from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope
from tensorflow.contrib.framework.python.ops import gen_variable_ops
from tensorflow.contrib.util import loader
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import training_util
from tensorflow.python.util.deprecation import deprecated


__all__ = ['add_model_variable',
           'assert_global_step',
           'assert_or_get_global_step',
           'assign_from_checkpoint',
           'assign_from_checkpoint_fn',
           'assign_from_values',
           'assign_from_values_fn',
           'create_global_step',
           'filter_variables',
           'get_global_step',
           'get_or_create_global_step',
           'get_local_variables',
           'get_model_variables',
           'get_trainable_variables',
           'get_unique_variable',
           'get_variables_by_name',
           'get_variables_by_suffix',
           'get_variable_full_name',
           'get_variables_to_restore',
           'get_variables',
           'global_variable',
           'local_variable',
           'model_variable',
           'variable',
           'VariableDeviceChooser',
           'zero_initializer']


def zero_initializer(ref, use_locking=True, name="zero_initializer"):
  """Initialize 'ref' with all zeros.

  The `ref` tensor must be uninitialized; if it is already initialized, a
  `ValueError` is raised. This op is intended to save memory during
  initialization.

  Args:
    ref: The tensor to be zero-initialized.
    use_locking: If `True`, use locking during the assignment.
    name: Optional name for this operation.

  Returns:
    `ref` once it has been initialized.

  Raises:
    ValueError: If the `ref` tensor is already initialized.
  """
  loader.load_op_library(
      resource_loader.get_path_to_datafile("_variable_ops.so"))
  if resource_variable_ops.is_resource_variable(ref):
    return gen_variable_ops.zero_var_initializer(
        ref.handle, shape=ref.shape, dtype=ref.dtype, name=name)
  else:
    return gen_variable_ops.zero_initializer(ref, name=name)
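
# Example usage (a minimal sketch; the variable name is illustrative and a
# TF 1.x graph/session environment is assumed):
#
#   acc = variable('accumulator', shape=[1024], dtype=dtypes.float32)
#   init_op = zero_initializer(acc)
#   with tf.Session() as sess:
#     sess.run(init_op)  # 'acc' now holds zeros; no large init value in graph.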


@deprecated(None, "Please switch to tf.train.assert_global_step")
def assert_global_step(global_step_tensor):
  training_util.assert_global_step(global_step_tensor)


def assert_or_get_global_step(graph=None, global_step_tensor=None):
  """Verifies that a global step tensor is valid or gets one if None is given.

  If `global_step_tensor` is not None, check that it is a valid global step
  tensor (using `assert_global_step`). Otherwise find a global step tensor
  using `get_global_step` and return it.

  Args:
    graph: The graph to find the global step tensor for.
    global_step_tensor: The tensor to check for suitability as a global step.
      If None is given (the default), find a global step tensor.

  Returns:
    A tensor suitable as a global step, or `None` if none was provided and
    none was found.
  """
  if global_step_tensor is None:
    # Get the global step tensor the same way the supervisor would.
    global_step_tensor = get_global_step(graph)
  else:
    assert_global_step(global_step_tensor)
  return global_step_tensor


@deprecated(None, "Please switch to tf.train.get_global_step")
def get_global_step(graph=None):
  return training_util.get_global_step(graph)


@deprecated(None, "Please switch to tf.train.create_global_step")
def create_global_step(graph=None):
  """Create global step tensor in graph.

  This API is deprecated. Use core framework training version instead.

  Args:
    graph: The graph in which to create the global step tensor. If missing,
      use default graph.

  Returns:
    Global step tensor.

  Raises:
    ValueError: if global step tensor is already defined.
  """
  return training_util.create_global_step(graph)


@deprecated(None, "Please switch to tf.train.get_or_create_global_step")
def get_or_create_global_step(graph=None):
  """Returns (and creates, if necessary) the global step tensor.

  Args:
    graph: The graph in which to create the global step tensor. If missing,
      use default graph.

  Returns:
    The global step tensor.
  """
  return training_util.get_or_create_global_step(graph)


def local_variable(initial_value,
                   validate_shape=True,
                   name=None,
                   use_resource=None):
  """Create a variable with a value and add it to `GraphKeys.LOCAL_VARIABLES`.

  Args:
    initial_value: See variables.Variable.__init__.
    validate_shape: See variables.Variable.__init__.
    name: See variables.Variable.__init__.
    use_resource: If `True` use a ResourceVariable instead of a Variable.

  Returns:
    New variable.
  """
  return variable_scope.variable(
      initial_value,
      trainable=False,
      collections=[ops.GraphKeys.LOCAL_VARIABLES],
      validate_shape=validate_shape,
      use_resource=use_resource,
      name=name)


def global_variable(initial_value,
                    validate_shape=True,
                    name=None,
                    use_resource=None):
  """Create a variable with a value and add it to `GraphKeys.GLOBAL_VARIABLES`.

  Args:
    initial_value: See variables.Variable.__init__.
    validate_shape: See variables.Variable.__init__.
    name: See variables.Variable.__init__.
    use_resource: If `True` use a ResourceVariable instead of a Variable.

  Returns:
    New variable.
  """
  return variable_scope.variable(
      initial_value,
      trainable=False,
      collections=[ops.GraphKeys.GLOBAL_VARIABLES],
      validate_shape=validate_shape,
      use_resource=use_resource,
      name=name)
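
# Example usage (a minimal sketch; names and values are illustrative):
#
#   counter = local_variable(0, name='counter')  # GraphKeys.LOCAL_VARIABLES
#   step = global_variable(0, name='step')       # GraphKeys.GLOBAL_VARIABLES
#   assert counter in tf.local_variables()
#   assert step in tf.global_variables()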


@contrib_add_arg_scope
def variable(name,
             shape=None,
             dtype=None,
             initializer=None,
             regularizer=None,
             trainable=True,
             collections=None,
             caching_device=None,
             device=None,
             partitioner=None,
             custom_getter=None,
             use_resource=None,
             synchronization=variables.VariableSynchronization.AUTO,
             aggregation=variables.VariableAggregation.NONE):
  """Gets an existing variable with these parameters or creates a new one.

  Args:
    name: the name of the new or existing variable.
    shape: shape of the new or existing variable.
    dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
    initializer: initializer for the variable if one is created.
    regularizer: a (Tensor -> Tensor or None) function; the result of applying
      it on a newly created variable will be added to the collection
      `GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    collections: A list of collection names to which the Variable will be
      added. If None, it defaults to `tf.GraphKeys.GLOBAL_VARIABLES`.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device.
    device: Optional device to place the variable. It can be a string or a
      function that is called to get the device for the variable.
    partitioner: Optional callable that accepts a fully defined `TensorShape`
      and dtype of the `Variable` to be created, and returns a list of
      partitions for each axis (currently only one axis can be partitioned).
    custom_getter: Callable that allows overwriting the internal get_variable
      method and has to have the same signature.
    use_resource: If `True` use a ResourceVariable instead of a Variable.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses when to
      synchronize. If `synchronization` is set to `ON_READ`, `trainable` must
      not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.

  Returns:
    The created or existing variable.
  """
  collections = list(collections if collections is not None
                     else [ops.GraphKeys.GLOBAL_VARIABLES])

  # Remove duplicates.
  collections = list(set(collections))
  getter = variable_scope.get_variable
  if custom_getter is not None:
    getter = functools.partial(
        custom_getter, reuse=variable_scope.get_variable_scope().reuse)
  with ops.device(device or ''):
    return getter(
        name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        regularizer=regularizer,
        trainable=trainable,
        collections=collections,
        caching_device=caching_device,
        partitioner=partitioner,
        use_resource=use_resource,
        synchronization=synchronization,
        aggregation=aggregation)
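
# Example usage (a minimal sketch; the initializer, regularizer and device
# below are illustrative choices, not requirements):
#
#   weights = variable('weights',
#                      shape=[100, 100],
#                      initializer=tf.truncated_normal_initializer(stddev=0.1),
#                      regularizer=tf.contrib.layers.l2_regularizer(0.05),
#                      device='/cpu:0')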


@contrib_add_arg_scope
def model_variable(name,
                   shape=None,
                   dtype=dtypes.float32,
                   initializer=None,
                   regularizer=None,
                   trainable=True,
                   collections=None,
                   caching_device=None,
                   device=None,
                   partitioner=None,
                   custom_getter=None,
                   use_resource=None,
                   synchronization=variables.VariableSynchronization.AUTO,
                   aggregation=variables.VariableAggregation.NONE):
  """Gets an existing model variable with these parameters or creates a new one.

  Args:
    name: the name of the new or existing variable.
    shape: shape of the new or existing variable.
    dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
    initializer: initializer for the variable if one is created.
    regularizer: a (Tensor -> Tensor or None) function; the result of applying
      it on a newly created variable will be added to the collection
      `GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    collections: A list of collection names to which the Variable will be
      added. Note that the variable is always also added to the
      `GraphKeys.GLOBAL_VARIABLES` and `GraphKeys.MODEL_VARIABLES` collections.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device.
    device: Optional device to place the variable. It can be a string or a
      function that is called to get the device for the variable.
    partitioner: Optional callable that accepts a fully defined `TensorShape`
      and dtype of the `Variable` to be created, and returns a list of
      partitions for each axis (currently only one axis can be partitioned).
    custom_getter: Callable that allows overwriting the internal get_variable
      method and has to have the same signature.
    use_resource: If `True` use a ResourceVariable instead of a Variable.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses when to
      synchronize. If `synchronization` is set to `ON_READ`, `trainable` must
      not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.

  Returns:
    The created or existing variable.
  """
  collections = list(collections or [])
  collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES]
  var = variable(
      name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      regularizer=regularizer,
      trainable=trainable,
      collections=collections,
      caching_device=caching_device,
      device=device,
      partitioner=partitioner,
      custom_getter=custom_getter,
      use_resource=use_resource,
      synchronization=synchronization,
      aggregation=aggregation)
  return var


def add_model_variable(var):
  """Adds a variable to the `GraphKeys.MODEL_VARIABLES` collection.

  Args:
    var: a variable.
  """
  if var not in ops.get_collection(ops.GraphKeys.MODEL_VARIABLES):
    ops.add_to_collection(ops.GraphKeys.MODEL_VARIABLES, var)
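
# Example usage (a minimal sketch; names and shapes are illustrative):
#
#   weights = model_variable('weights', shape=[10, 10],
#                            initializer=tf.truncated_normal_initializer())
#   assert weights in get_model_variables()
#
#   # Variables created through other means can be tagged after the fact:
#   other = tf.Variable(0.0, name='other')
#   add_model_variable(other)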


def get_variables(scope=None, suffix=None,
                  collection=ops.GraphKeys.GLOBAL_VARIABLES):
  """Gets the list of variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return. Can be a
      variable scope or a string.
    suffix: an optional suffix for filtering the variables to return.
    collection: the collection to search in. Defaults to
      `GraphKeys.GLOBAL_VARIABLES`.

  Returns:
    a list of variables in the collection with scope and suffix.
  """
  if isinstance(scope, variable_scope.VariableScope):
    scope = scope.name
  if suffix is not None:
    if ':' not in suffix:
      suffix += ':'
    scope = (scope or '') + '.*' + suffix
  return ops.get_collection(collection, scope)


def get_model_variables(scope=None, suffix=None):
  """Gets the list of model variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.

  Returns:
    a list of variables in the model variables collection with scope and
    suffix.
  """
  return get_variables(scope, suffix, ops.GraphKeys.MODEL_VARIABLES)


def get_local_variables(scope=None, suffix=None):
  """Gets the list of local variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.

  Returns:
    a list of variables in the local variables collection with scope and
    suffix.
  """
  return get_variables(scope, suffix, ops.GraphKeys.LOCAL_VARIABLES)


def get_trainable_variables(scope=None, suffix=None):
  """Gets the list of trainable variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.

  Returns:
    a list of variables in the trainable collection with scope and suffix.
  """
  return get_variables(scope, suffix, ops.GraphKeys.TRAINABLE_VARIABLES)


def get_variables_to_restore(include=None, exclude=None):
  """Gets the list of the variables to restore.

  Args:
    include: an optional list/tuple of scope strings for filtering which
      variables from the VARIABLES collection to include. If `None`, all
      variables are included.
    exclude: an optional list/tuple of scope strings for filtering which
      variables from the VARIABLES collection to exclude. If `None`, no
      variables are excluded.

  Returns:
    a list of variables to restore.

  Raises:
    TypeError: include or exclude is provided but is not a list or a tuple.
  """
  if include is None:
    # Include all variables.
    vars_to_include = get_variables()
  else:
    if not isinstance(include, (list, tuple)):
      raise TypeError('include is provided but is not a list or a tuple.')
    vars_to_include = []
    for scope in include:
      vars_to_include += get_variables(scope)
  vars_to_exclude = set()
  if exclude is not None:
    if not isinstance(exclude, (list, tuple)):
      raise TypeError('exclude is provided but is not a list or a tuple.')
    for scope in exclude:
      vars_to_exclude |= set(get_variables(scope))
  # Exclude the variables in vars_to_exclude.
  return [v for v in vars_to_include if v not in vars_to_exclude]
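
# Example usage (a minimal sketch; the scope names are illustrative):
#
#   # Restore all variables under 'vgg_16' except the final logits layer:
#   to_restore = get_variables_to_restore(include=['vgg_16'],
#                                         exclude=['vgg_16/fc8'])
#   saver = tf_saver.Saver(to_restore)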


def get_variables_by_suffix(suffix, scope=None):
  """Gets the list of variables that end with the given suffix.

  Args:
    suffix: suffix for filtering the variables to return.
    scope: an optional scope for filtering the variables to return.

  Returns:
    a copied list of variables with the given suffix and scope.
  """
  return get_variables(scope=scope, suffix=suffix)


def get_variables_by_name(given_name, scope=None):
  """Gets the list of variables that were given that name.

  Args:
    given_name: name given to the variable without any scope.
    scope: an optional scope for filtering the variables to return.

  Returns:
    a copied list of variables with the given name and scope.
  """
  suffix = '/' + given_name + ':|^' + given_name + ':'
  return get_variables(scope=scope, suffix=suffix)


def get_unique_variable(var_op_name):
  """Gets the variable uniquely identified by that var_op_name.

  Args:
    var_op_name: the full name of the variable op, including the scope.

  Returns:
    a tensorflow variable.

  Raises:
    ValueError: if no variable uniquely identified by the name exists.
  """
  candidates = get_variables(scope=var_op_name)
  if not candidates:
    raise ValueError('Couldn\'t find variable %s' % var_op_name)

  for candidate in candidates:
    if candidate.op.name == var_op_name:
      return candidate
  raise ValueError('Variable %s does not uniquely identify a variable' %
                   var_op_name)


def assign_from_values(var_names_to_values):
  """Creates an assignment operation from a given mapping.

  This function provides a mechanism for performing assignment of variables
  to values in a way that does not fill the graph with large assignment values.

  Args:
    var_names_to_values: A map from variable names to values.

  Returns:
    assign_op: An `Operation` that assigns each of the given variables to the
      requested values.
    feed_dict: The feed dictionary to use when evaluating `assign_op`.

  Raises:
    ValueError: if any of the given variable names were not found.
  """
  feed_dict = {}
  assign_ops = []

  for var_name in var_names_to_values:
    var_value = var_names_to_values[var_name]
    var = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, var_name)
    if not var:
      raise ValueError('Variable %s wasn\'t found' % var_name)
    elif len(var) > 1:
      # tf.get_collection is just a filter on the prefix: find the exact match.
      found = False
      for v in var:
        if v.op.name == var_name:
          var = v
          found = True
          break

      if not found:
        raise ValueError('Variable %s doesn\'t uniquely identify a variable' %
                         var_name)
    else:
      var = var[0]

    # TODO(nsilberman): ensure placeholder and assign are on the same device.
    # Assign a placeholder to the value that will be filled later.
    placeholder_name = 'placeholder/' + var.op.name
    placeholder_value = array_ops.placeholder(
        dtype=var.dtype.base_dtype,
        shape=var.get_shape(),
        name=placeholder_name)
    assign_ops.append(var.assign(placeholder_value))

    feed_dict[placeholder_value] = var_value.reshape(var.get_shape())

  assign_op = control_flow_ops.group(*assign_ops)
  return assign_op, feed_dict
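
# Example usage (a minimal sketch; assumes `numpy` imported as `np` and a
# variable named 'weights' created beforehand, e.g. via variable() above):
#
#   weights = variable('weights', shape=[100, 100])
#   assign_op, feed_dict = assign_from_values(
#       {'weights': np.zeros([100, 100], dtype=np.float32)})
#   with tf.Session() as sess:
#     sess.run(assign_op, feed_dict)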


def assign_from_values_fn(var_names_to_values):
  """Returns a function that assigns specific variables from the given values.

  This function provides a mechanism for performing assignment of variables
  to values in a way that does not fill the graph with large assignment values.

  Args:
    var_names_to_values: A map from variable names to values.

  Returns:
    A function that takes a single argument, a `tf.Session`, that applies the
    assignment operation.

  Raises:
    ValueError: if any of the given variable names were not found.
  """
  assign_op, feed_dict = assign_from_values(var_names_to_values)

  def callback(session):
    return session.run(assign_op, feed_dict)
  return callback


# pylint: disable=protected-access
# Currently variable_scope doesn't provide very good APIs to access
# all variables under scope and retrieve and check existing scopes.
def get_variable_full_name(var):
  """Returns the full name of a variable.

  For normal Variables, this is the same as the var.op.name. For sliced or
  PartitionedVariables, this name is the same for all the slices/partitions.
  In both cases, this is normally the name used in a checkpoint file.

  Args:
    var: A `Variable` object.

  Returns:
    A string that is the full name.
  """
  if var._save_slice_info:
    return var._save_slice_info.full_name
  else:
    return var.op.name


# TODO(nsilberman): add flag to load exponential moving averages instead
#
# TODO(sguada): Update docs in slim/g3doc/index.md to describe
# the new feature where the var_list dictionary can have values that
# are each a list of Variables.
def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False):
  """Creates an operation to assign specific variables from a checkpoint.

  Args:
    model_path: The full path to the model checkpoint. To get the latest
      checkpoint use `model_path = tf.train.latest_checkpoint(checkpoint_dir)`.
    var_list: A list of (possibly partitioned) `Variable` objects or a
      dictionary mapping names in the checkpoint to the corresponding
      variables or list of variables to initialize from that checkpoint value.
      For partitioned Variables, the name in the checkpoint must be the name
      of the full variable, not that of an individual partition, e.g. "my_var"
      rather than "my_var/part_4". If empty, returns no_op(), {}.
    ignore_missing_vars: Boolean, if True ignore variables missing in the
      checkpoint with a warning instead of failing.

  Returns:
    the restore_op and the feed_dict that need to be run to restore var_list.

  Raises:
    ValueError: If `ignore_missing_vars` is False and the checkpoint specified
      at `model_path` is missing one of the variables in `var_list`.
  """
  # Normalize var_list into a dictionary mapping names in the
  # checkpoint to the list of variables to initialize from that
  # checkpoint variable. Sliced (including partitioned) variables will
  # end up under the same key.
  grouped_vars = {}
  if isinstance(var_list, (tuple, list)):
    for var in var_list:
      ckpt_name = get_variable_full_name(var)
      if ckpt_name not in grouped_vars:
        grouped_vars[ckpt_name] = []
      grouped_vars[ckpt_name].append(var)
  else:
    for ckpt_name, value in var_list.items():
      if isinstance(value, (tuple, list)):
        grouped_vars[ckpt_name] = value
      else:
        grouped_vars[ckpt_name] = [value]

  # Read each checkpoint entry. Create a placeholder variable and
  # add the (possibly sliced) data from the checkpoint to the feed_dict.
  reader = pywrap_tensorflow.NewCheckpointReader(model_path)
  feed_dict = {}
  assign_ops = []
  for ckpt_name in grouped_vars:
    if not reader.has_tensor(ckpt_name):
      log_str = 'Checkpoint is missing variable [%s]' % ckpt_name
      if ignore_missing_vars:
        logging.warning(log_str)
        continue
      else:
        raise ValueError(log_str)
    ckpt_value = reader.get_tensor(ckpt_name)

    for var in grouped_vars[ckpt_name]:
      placeholder_tensor = array_ops.placeholder(
          dtype=var.dtype.base_dtype,
          shape=var.get_shape(),
          name='placeholder/' + var.op.name)
      assign_ops.append(var.assign(placeholder_tensor))

      if not var._save_slice_info:
        if var.get_shape() != ckpt_value.shape:
          raise ValueError(
              'Total size of new array must be unchanged for %s '
              'lh_shape: [%s], rh_shape: [%s]'
              % (ckpt_name, str(ckpt_value.shape), str(var.get_shape())))

        feed_dict[placeholder_tensor] = ckpt_value.reshape(ckpt_value.shape)
      else:
        slice_dims = zip(var._save_slice_info.var_offset,
                         var._save_slice_info.var_shape)
        slice_dims = [(start, start + size) for (start, size) in slice_dims]
        slice_dims = [slice(*x) for x in slice_dims]
        # Index with a tuple of slices; indexing a numpy array with a list of
        # slices is deprecated.
        slice_value = ckpt_value[tuple(slice_dims)]
        slice_value = slice_value.reshape(var._save_slice_info.var_shape)
        feed_dict[placeholder_tensor] = slice_value

  assign_op = control_flow_ops.group(*assign_ops)
  return assign_op, feed_dict
# pylint: enable=protected-access
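
# Example usage (a minimal sketch; the checkpoint directory and exclude scope
# are illustrative):
#
#   ckpt_path = tf.train.latest_checkpoint('/tmp/model_dir')
#   restore_op, feed_dict = assign_from_checkpoint(
#       ckpt_path, get_variables_to_restore(exclude=['logits']))
#   with tf.Session() as sess:
#     sess.run(restore_op, feed_dict)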


def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False,
                              reshape_variables=False):
  """Returns a function that assigns specific variables from a checkpoint.

  If ignore_missing_vars is True and no variables are found in the checkpoint
  it returns None.

  Args:
    model_path: The full path to the model checkpoint. To get the latest
      checkpoint use `model_path = tf.train.latest_checkpoint(checkpoint_dir)`.
    var_list: A list of `Variable` objects or a dictionary mapping names in
      the checkpoint to the corresponding variables to initialize. If empty
      or `None`, a ValueError is raised.
    ignore_missing_vars: Boolean, if True ignore variables missing in the
      checkpoint with a warning instead of failing.
    reshape_variables: Boolean, if True automatically reshape variables whose
      shape differs from the one stored in the checkpoint but which have the
      same number of elements.

  Returns:
    A function that takes a single argument, a `tf.Session`, that applies the
    assignment operation. If no matching variables were found in the
    checkpoint then `None` is returned.

  Raises:
    ValueError: If var_list is empty.
  """
  if not var_list:
    raise ValueError('var_list cannot be empty')
  if ignore_missing_vars:
    reader = pywrap_tensorflow.NewCheckpointReader(model_path)
    if isinstance(var_list, dict):
      var_dict = var_list
    else:
      var_dict = {var.op.name: var for var in var_list}
    available_vars = {}
    for var in var_dict:
      if reader.has_tensor(var):
        available_vars[var] = var_dict[var]
      else:
        logging.warning(
            'Variable %s missing in checkpoint %s', var, model_path)
    var_list = available_vars
  if var_list:
    saver = tf_saver.Saver(var_list, reshape=reshape_variables,
                           write_version=saver_pb2.SaverDef.V1)

    def callback(session):
      saver.restore(session, model_path)
    return callback
  else:
    logging.warning('No Variables to restore')
    return None
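
# Example usage (a minimal sketch; the path is illustrative). The returned
# callback matches APIs that expect an initializer function of one session
# argument:
#
#   init_fn = assign_from_checkpoint_fn('/tmp/model_dir/model.ckpt',
#                                       get_model_variables(),
#                                       ignore_missing_vars=True)
#   with tf.Session() as sess:
#     if init_fn is not None:
#       init_fn(sess)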


class VariableDeviceChooser(object):
  """Device chooser for variables.

  When using a parameter server, variables are assigned to parameter-server
  tasks in a round-robin fashion. When not using a parameter server, it allows
  GPU or CPU placement.
  """

  def __init__(self,
               num_tasks=0,
               job_name='ps',
               device_type='CPU',
               device_index=0,
               replica=None):
    """Initialize VariableDeviceChooser.

    Usage:
      To use with 2 parameter servers:
        VariableDeviceChooser(2)

      To use without parameter servers:
        VariableDeviceChooser()
        VariableDeviceChooser(device_type='GPU')  # For GPU placement

    Args:
      num_tasks: number of tasks.
      job_name: String, a name for the parameter server job.
      device_type: Optional device type string (e.g. "CPU" or "GPU").
      device_index: int. Optional device index. If left unspecified, device
        represents 'any' device_index.
      replica: Optional int, the replica index to use in the device spec. If
        `None` (the default), the replica is left unspecified.
    """
    self._job_name = job_name
    self._device_type = device_type
    self._device_index = device_index
    self._replica = replica
    self._num_tasks = num_tasks
    self._next_task_id = 0

  def __call__(self, op):
    device_spec = tf_device.DeviceSpec(
        replica=self._replica,
        device_type=self._device_type,
        device_index=self._device_index)
    if self._num_tasks > 0:
      task_id = self._next_task_id
      self._next_task_id = (self._next_task_id + 1) % self._num_tasks
      device_spec.job = self._job_name
      device_spec.task = task_id
    return device_spec.to_string()
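
# Example usage (a minimal sketch; two parameter-server tasks are assumed).
# The chooser is a callable, so it can be passed as the `device` argument of
# variable() above, which round-robins placement across the ps tasks:
#
#   chooser = VariableDeviceChooser(num_tasks=2)
#   a = variable('a', shape=[20, 1], device=chooser)  # /job:ps/task:0/...
#   b = variable('b', shape=[20, 1], device=chooser)  # /job:ps/task:1/...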


def filter_variables(var_list, include_patterns=None, exclude_patterns=None,
                     reg_search=True):
  """Filter a list of variables using regular expressions.

  First includes variables according to the list of include_patterns.
  Afterwards, eliminates variables according to the list of exclude_patterns.

  For example, one can obtain a list of variables with the weights of all
  convolutional layers (depending on the network definition) by:

  ```python
  variables = tf.contrib.framework.get_model_variables()
  conv_weight_variables = tf.contrib.framework.filter_variables(
      variables,
      include_patterns=['Conv'],
      exclude_patterns=['biases', 'Logits'])
  ```

  Args:
    var_list: list of variables.
    include_patterns: list of regular expressions to include. Defaults to
      None, which means all variables pass the include filter. A variable is
      included if it matches any of the include_patterns.
    exclude_patterns: list of regular expressions to exclude. Defaults to
      None, which means no variables are excluded. A variable is excluded if
      it matches any of the exclude_patterns.
    reg_search: boolean. If True (default), performs re.search to find matches
      (i.e. pattern can match any substring of the variable name). If False,
      performs re.match (i.e. regexp should match from the beginning of the
      variable name).

  Returns:
    filtered list of variables.
  """
  if reg_search:
    reg_exp_func = re.search
  else:
    reg_exp_func = re.match

  # First include variables.
  if include_patterns is None:
    included_variables = list(var_list)
  else:
    included_variables = []
    for var in var_list:
      if any(reg_exp_func(ptrn, var.name) for ptrn in include_patterns):
        included_variables.append(var)

  # Afterwards, exclude variables.
  if exclude_patterns is None:
    filtered_variables = included_variables
  else:
    filtered_variables = []
    for var in included_variables:
      if not any(reg_exp_func(ptrn, var.name) for ptrn in exclude_patterns):
        filtered_variables.append(var)

  return filtered_variables