# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Variable functions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import re

from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope
from tensorflow.contrib.framework.python.ops import gen_variable_ops
from tensorflow.contrib.util import loader
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import training_util
from tensorflow.python.util.deprecation import deprecated


__all__ = ['add_model_variable',
           'assert_global_step',
           'assert_or_get_global_step',
           'assign_from_checkpoint',
           'assign_from_checkpoint_fn',
           'assign_from_values',
           'assign_from_values_fn',
           'create_global_step',
           'filter_variables',
           'get_global_step',
           'get_or_create_global_step',
           'get_local_variables',
           'get_model_variables',
           'get_trainable_variables',
           'get_unique_variable',
           'get_variables_by_name',
           'get_variables_by_suffix',
           'get_variable_full_name',
           'get_variables_to_restore',
           'get_variables',
           'global_variable',
           'local_variable',
           'model_variable',
           'variable',
           'VariableDeviceChooser',
           'zero_initializer']


def zero_initializer(ref, use_locking=True, name="zero_initializer"):
  """Initializes 'ref' with all zeros.

  The ref tensor must be uninitialized; if it is already initialized, a
  ValueError is raised. This op is intended to save memory during
  initialization.

  Args:
    ref: ref of the tensor to be zero initialized.
    use_locking: whether to lock during the assignment (currently not
      forwarded to the underlying op).
    name: optional name for this operation.

  Returns:
    The initialized ref.

  Raises:
    ValueError: If the ref tensor is already initialized.
  """
  loader.load_op_library(
      resource_loader.get_path_to_datafile("_variable_ops.so"))
  return gen_variable_ops.zero_initializer(ref, name=name)
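
# Example (hedged sketch, not part of the original module): zero-initializing
# a variable without adding a large zeros constant to the graph. Assumes
# graph mode; `tf` is the public TensorFlow 1.x API and the variable name and
# shape are illustrative.
#
#   v = variable('embeddings', shape=[1024, 128], dtype=dtypes.float32)
#   init_op = zero_initializer(v)  # run this instead of v.initializer
#   with tf.Session() as sess:
#     sess.run(init_op)            # v now holds all zeros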


@deprecated(None, "Please switch to tf.train.assert_global_step")
def assert_global_step(global_step_tensor):
  training_util.assert_global_step(global_step_tensor)


def assert_or_get_global_step(graph=None, global_step_tensor=None):
  """Verifies that a global step tensor is valid, or gets one if None is given.

  If `global_step_tensor` is not None, checks that it is a valid global step
  tensor (using `assert_global_step`). Otherwise, finds a global step tensor
  using `get_global_step` and returns it.

  Args:
    graph: The graph to find the global step tensor for.
    global_step_tensor: The tensor to check for suitability as a global step.
      If None is given (the default), find a global step tensor.

  Returns:
    A tensor suitable as a global step, or `None` if none was provided and none
    was found.
  """
  if global_step_tensor is None:
    # Get the global step tensor the same way the supervisor would.
    global_step_tensor = get_global_step(graph)
  else:
    assert_global_step(global_step_tensor)
  return global_step_tensor


@deprecated(None, "Please switch to tf.train.get_global_step")
def get_global_step(graph=None):
  return training_util.get_global_step(graph)


@deprecated(None, "Please switch to tf.train.create_global_step")
def create_global_step(graph=None):
  """Create global step tensor in graph.

  This API is deprecated. Use core framework training version instead.

  Args:
    graph: The graph in which to create the global step tensor. If missing,
      use default graph.

  Returns:
    Global step tensor.

  Raises:
    ValueError: if global step tensor is already defined.
  """
  return training_util.create_global_step(graph)


@deprecated(None, "Please switch to tf.train.get_or_create_global_step")
def get_or_create_global_step(graph=None):
  """Returns the global step tensor, creating it if necessary.

  Args:
    graph: The graph in which to create the global step tensor. If missing, use
      default graph.

  Returns:
    The global step tensor.
  """
  return training_util.get_or_create_global_step(graph)


def local_variable(initial_value,
                   validate_shape=True,
                   name=None,
                   use_resource=None):
  """Create a variable with a value and add it to `GraphKeys.LOCAL_VARIABLES`.

  Args:
    initial_value: See variables.Variable.__init__.
    validate_shape: See variables.Variable.__init__.
    name: See variables.Variable.__init__.
    use_resource: If `True` use a ResourceVariable instead of a Variable.

  Returns:
    New variable.
  """
  return variable_scope.variable(
      initial_value, trainable=False,
      collections=[ops.GraphKeys.LOCAL_VARIABLES],
      validate_shape=validate_shape,
      use_resource=use_resource,
      name=name)


def global_variable(initial_value,
                    validate_shape=True,
                    name=None,
                    use_resource=None):
  """Create a variable with a value and add it to `GraphKeys.GLOBAL_VARIABLES`.

  Args:
    initial_value: See variables.Variable.__init__.
    validate_shape: See variables.Variable.__init__.
    name: See variables.Variable.__init__.
    use_resource: If `True` use a ResourceVariable instead of a Variable.

  Returns:
    New variable.
  """
  return variable_scope.variable(
      initial_value, trainable=False,
      collections=[ops.GraphKeys.GLOBAL_VARIABLES],
      validate_shape=validate_shape,
      use_resource=use_resource,
      name=name)
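
# Example (hedged sketch, not part of the original module): the two helpers
# above create non-trainable variables that differ only in their collection.
# Names below are illustrative; `tf` is the public TensorFlow 1.x API.
#
#   total = local_variable(0.0, name='metrics/total')  # LOCAL_VARIABLES
#   step = global_variable(0, name='my_step')          # GLOBAL_VARIABLES
#   # Local variables are initialized by tf.local_variables_initializer()
#   # and are not saved or restored by default.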


@contrib_add_arg_scope
def variable(name, shape=None, dtype=None, initializer=None,
             regularizer=None, trainable=True, collections=None,
             caching_device=None, device=None,
             partitioner=None, custom_getter=None, use_resource=None):
  """Gets an existing variable with these parameters or creates a new one.

  Args:
    name: the name of the new or existing variable.
    shape: shape of the new or existing variable.
    dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
    initializer: initializer for the variable if one is created.
    regularizer: a (Tensor -> Tensor or None) function; the result of applying
      it on a newly created variable will be added to the collection
      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    collections: A list of collection names to which the Variable will be
      added. If None, defaults to `tf.GraphKeys.GLOBAL_VARIABLES`.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device.
    device: Optional device to place the variable. It can be a string or a
      function that is called to get the device for the variable.
    partitioner: Optional callable that accepts a fully defined `TensorShape`
      and dtype of the `Variable` to be created, and returns a list of
      partitions for each axis (currently only one axis can be partitioned).
    custom_getter: Callable that allows overwriting the internal get_variable
      method and has to have the same signature.
    use_resource: If `True` use a ResourceVariable instead of a Variable.

  Returns:
    The created or existing variable.
  """
  collections = list(collections if collections is not None
                     else [ops.GraphKeys.GLOBAL_VARIABLES])

  # Remove duplicates.
  collections = list(set(collections))
  getter = variable_scope.get_variable
  if custom_getter is not None:
    getter = functools.partial(custom_getter,
                               reuse=variable_scope.get_variable_scope().reuse)
  with ops.device(device or ''):
    return getter(name, shape=shape, dtype=dtype,
                  initializer=initializer,
                  regularizer=regularizer,
                  trainable=trainable,
                  collections=collections,
                  caching_device=caching_device,
                  partitioner=partitioner,
                  use_resource=use_resource)

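# Example (hedged sketch, not part of the original module): creating a
# regularized weight variable pinned to the CPU. The initializer and
# regularizer come from the public TensorFlow 1.x API; names and values are
# illustrative.
#
#   weights = variable('weights',
#                      shape=[784, 10],
#                      initializer=tf.truncated_normal_initializer(stddev=0.1),
#                      regularizer=tf.contrib.layers.l2_regularizer(0.05),
#                      device='/cpu:0')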

@contrib_add_arg_scope
def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
                   regularizer=None, trainable=True, collections=None,
                   caching_device=None, device=None, partitioner=None,
                   custom_getter=None, use_resource=None):
  """Gets an existing model variable with these parameters or creates a new one.

  Args:
    name: the name of the new or existing variable.
    shape: shape of the new or existing variable.
    dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
    initializer: initializer for the variable if one is created.
    regularizer: a (Tensor -> Tensor or None) function; the result of applying
      it on a newly created variable will be added to the collection
      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    collections: A list of collection names to which the Variable will be
      added. Note that the variable is always also added to the
      `GraphKeys.GLOBAL_VARIABLES` and `GraphKeys.MODEL_VARIABLES` collections.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device.
    device: Optional device to place the variable. It can be a string or a
      function that is called to get the device for the variable.
    partitioner: Optional callable that accepts a fully defined `TensorShape`
      and dtype of the `Variable` to be created, and returns a list of
      partitions for each axis (currently only one axis can be partitioned).
    custom_getter: Callable that allows overwriting the internal get_variable
      method and has to have the same signature.
    use_resource: If `True` use a ResourceVariable instead of a Variable.

  Returns:
    The created or existing variable.
  """
  collections = list(collections or [])
  collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES]
  var = variable(name, shape=shape, dtype=dtype,
                 initializer=initializer, regularizer=regularizer,
                 trainable=trainable, collections=collections,
                 caching_device=caching_device, device=device,
                 partitioner=partitioner, custom_getter=custom_getter,
                 use_resource=use_resource)
  return var

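# Example (hedged sketch, not part of the original module): because
# `model_variable` is decorated with `@contrib_add_arg_scope`, defaults can be
# set for a whole block via `arg_scope`. The import path is the public contrib
# alias; names are illustrative.
#
#   from tensorflow.contrib.framework import arg_scope
#   with arg_scope([model_variable], device='/cpu:0'):
#     weights = model_variable('weights', shape=[3, 3, 64, 64])
#   # `weights` is placed on the CPU and added to both GLOBAL_VARIABLES
#   # and MODEL_VARIABLES.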

def add_model_variable(var):
  """Adds a variable to the `GraphKeys.MODEL_VARIABLES` collection.

  Args:
    var: a variable.
  """
  if var not in ops.get_collection(ops.GraphKeys.MODEL_VARIABLES):
    ops.add_to_collection(ops.GraphKeys.MODEL_VARIABLES, var)


def get_variables(scope=None, suffix=None,
                  collection=ops.GraphKeys.GLOBAL_VARIABLES):
  """Gets the list of variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return. Can be a
      variable scope or a string.
    suffix: an optional suffix for filtering the variables to return.
    collection: the collection to search in. Defaults to
      `GraphKeys.GLOBAL_VARIABLES`.

  Returns:
    a list of variables in collection with scope and suffix.
  """
  if isinstance(scope, variable_scope.VariableScope):
    scope = scope.name
  if suffix is not None:
    if ':' not in suffix:
      suffix += ':'
    scope = (scope or '') + '.*' + suffix
  return ops.get_collection(collection, scope)

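# Example (hedged sketch, not part of the original module): filtering by
# scope and suffix. Scope and variable names below are illustrative; `tf` is
# the public TensorFlow 1.x API.
#
#   with tf.variable_scope('conv1'):
#     w = variable('weights', shape=[3, 3, 16, 16])
#     b = variable('biases', shape=[16])
#   get_variables('conv1')                    # -> [w, b]
#   get_variables('conv1', suffix='weights')  # -> [w]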

def get_model_variables(scope=None, suffix=None):
  """Gets the list of model variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.

  Returns:
    a list of variables in collection with scope and suffix.
  """
  return get_variables(scope, suffix, ops.GraphKeys.MODEL_VARIABLES)


def get_local_variables(scope=None, suffix=None):
  """Gets the list of local variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.

  Returns:
    a list of variables in collection with scope and suffix.
  """
  return get_variables(scope, suffix, ops.GraphKeys.LOCAL_VARIABLES)


def get_trainable_variables(scope=None, suffix=None):
  """Gets the list of trainable variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.

  Returns:
    a list of variables in the trainable collection with scope and suffix.
  """
  return get_variables(scope, suffix, ops.GraphKeys.TRAINABLE_VARIABLES)


def get_variables_to_restore(include=None, exclude=None):
  """Gets the list of the variables to restore.

  Args:
    include: an optional list/tuple of scope strings for filtering which
      variables from the VARIABLES collection to include. If None, all
      variables are included.
    exclude: an optional list/tuple of scope strings for filtering which
      variables from the VARIABLES collection to exclude. If None, no
      variables are excluded.

  Returns:
    a list of variables to restore.

  Raises:
    TypeError: include or exclude is provided but is not a list or a tuple.
  """
  if include is None:
    # Include all variables.
    vars_to_include = get_variables()
  else:
    if not isinstance(include, (list, tuple)):
      raise TypeError('include is provided but is not a list or a tuple.')
    vars_to_include = []
    for scope in include:
      vars_to_include += get_variables(scope)
  vars_to_exclude = set()
  if exclude is not None:
    if not isinstance(exclude, (list, tuple)):
      raise TypeError('exclude is provided but is not a list or a tuple.')
    for scope in exclude:
      vars_to_exclude |= set(get_variables(scope))
  # Exclude the variables in vars_to_exclude.
  return [v for v in vars_to_include if v not in vars_to_exclude]

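# Example (hedged sketch, not part of the original module): restoring all
# variables except those under an (illustrative) 'fc8' scope; `tf` is the
# public TensorFlow 1.x API and the checkpoint path is illustrative.
#
#   variables_to_restore = get_variables_to_restore(exclude=['fc8'])
#   restorer = tf.train.Saver(variables_to_restore)
#   with tf.Session() as sess:
#     restorer.restore(sess, '/tmp/model.ckpt')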

def get_variables_by_suffix(suffix, scope=None):
  """Gets the list of variables that end with the given suffix.

  Args:
    suffix: suffix for filtering the variables to return.
    scope: an optional scope for filtering the variables to return.

  Returns:
    a copied list of variables with the given suffix and scope.
  """
  return get_variables(scope=scope, suffix=suffix)


def get_variables_by_name(given_name, scope=None):
  """Gets the list of variables that were given that name.

  Args:
    given_name: name given to the variable without any scope.
    scope: an optional scope for filtering the variables to return.

  Returns:
    a copied list of variables with the given name and scope.
  """
  suffix = '/' + given_name + ':|^' + given_name + ':'
  return get_variables(scope=scope, suffix=suffix)


def get_unique_variable(var_op_name):
  """Gets the variable uniquely identified by var_op_name.

  Args:
    var_op_name: the full name of the variable op, including the scope.

  Returns:
    a TensorFlow variable.

  Raises:
    ValueError: if no variable uniquely identified by the name exists.
  """
  candidates = get_variables(scope=var_op_name)
  if not candidates:
    raise ValueError('Couldn\'t find variable %s' % var_op_name)

  for candidate in candidates:
    if candidate.op.name == var_op_name:
      return candidate
  raise ValueError('Variable %s does not uniquely identify a variable' %
                   var_op_name)


def assign_from_values(var_names_to_values):
  """Creates an assignment operation from a given mapping.

  This function provides a mechanism for performing assignment of variables
  to values in a way that does not fill the graph with large assignment values.

  Args:
    var_names_to_values: A map from variable names to values.

  Returns:
    assign_op: An `Operation` that assigns each of the given variables to the
      requested values.
    feed_dict: The feed dictionary to use when evaluating `assign_op`.

  Raises:
    ValueError: if any of the given variable names were not found.
  """
  feed_dict = {}
  assign_ops = []

  for var_name in var_names_to_values:
    var_value = var_names_to_values[var_name]
    var = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, var_name)
    if not var:
      raise ValueError('Variable %s wasn\'t found' % var_name)
    elif len(var) > 1:
      # tf.get_collection is just a filter on the prefix: find the exact match.
      found = False
      for v in var:
        if v.op.name == var_name:
          var = v
          found = True
          break

      if not found:
        raise ValueError('Variable %s doesn\'t uniquely identify a variable' %
                         var_name)
    else:
      var = var[0]

    # TODO(nsilberman): ensure placeholder and assign are on the same device.
    # Assign a placeholder to the value that will be filled later.
    placeholder_name = 'placeholder/' + var.op.name
    placeholder_value = array_ops.placeholder(
        dtype=var.dtype.base_dtype,
        shape=var.get_shape(),
        name=placeholder_name)
    assign_ops.append(var.assign(placeholder_value))

    feed_dict[placeholder_value] = var_value.reshape(var.get_shape())

  assign_op = control_flow_ops.group(*assign_ops)
  return assign_op, feed_dict

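# Example (hedged sketch, not part of the original module): assigning NumPy
# values by variable name; the values are fed through placeholders rather
# than embedded as constants in the graph. Assumes `numpy` is imported as
# `np`, `tf` is the public TensorFlow 1.x API, and the name is illustrative.
#
#   var_names_to_values = {'conv1/weights': np.zeros([3, 3, 1, 16])}
#   assign_op, feed_dict = assign_from_values(var_names_to_values)
#   with tf.Session() as sess:
#     sess.run(assign_op, feed_dict)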

def assign_from_values_fn(var_names_to_values):
  """Returns a function that assigns specific variables from the given values.

  This function provides a mechanism for performing assignment of variables
  to values in a way that does not fill the graph with large assignment values.

  Args:
    var_names_to_values: A map from variable names to values.

  Returns:
    A function that takes a single argument, a `tf.Session`, that applies the
    assignment operation.

  Raises:
    ValueError: if any of the given variable names were not found.
  """
  assign_op, feed_dict = assign_from_values(var_names_to_values)

  def callback(session):
    return session.run(assign_op, feed_dict)

  return callback

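# Example (hedged sketch, not part of the original module): the `_fn` variant
# packages the same assignment as a callback, e.g. for use as an init
# function. Continues the previous sketch's assumptions.
#
#   init_fn = assign_from_values_fn(var_names_to_values)
#   with tf.Session() as sess:
#     init_fn(sess)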

# pylint: disable=protected-access
# Currently variable_scope doesn't provide very good APIs to access
# all variables under scope and retrieve and check existing scopes.
def get_variable_full_name(var):
  """Returns the full name of a variable.

  For normal Variables, this is the same as the var.op.name.  For
  sliced or PartitionedVariables, this name is the same for all the
  slices/partitions. In both cases, this is normally the name used in
  a checkpoint file.

  Args:
    var: A `Variable` object.

  Returns:
    A string that is the full name.
  """
  if var._save_slice_info:
    return var._save_slice_info.full_name
  else:
    return var.op.name


# TODO(nsilberman): add flag to load exponential moving averages instead
#
# TODO(sguada): Update docs in slim/g3doc/index.md to describe
# the new feature where the var_list dictionary can have values that
# are each a list of Variables.
def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False):
  """Creates an operation to assign specific variables from a checkpoint.

  Args:
    model_path: The full path to the model checkpoint. To get the latest
      checkpoint use `model_path = tf.train.latest_checkpoint(checkpoint_dir)`.
    var_list: A list of (possibly partitioned) `Variable` objects or a
      dictionary mapping names in the checkpoint to the corresponding
      variables or list of variables to initialize from that checkpoint value.
      For partitioned Variables, the name in the checkpoint must be the full
      variable, not the name of the partitioned variable, e.g. "my_var" rather
      than "my_var/part_4". If empty, returns no_op(), {}.
    ignore_missing_vars: Boolean, if True ignore variables missing in the
      checkpoint with a warning instead of failing.

  Returns:
    the restore_op and the feed_dict that need to be run to restore var_list.

  Raises:
    ValueError: If `ignore_missing_vars` is False and the checkpoint specified
      at `model_path` is missing one of the variables in `var_list`.
  """
  # Normalize var_list into a dictionary mapping names in the
  # checkpoint to the list of variables to initialize from that
  # checkpoint variable. Sliced (including partitioned) variables will
  # end up under the same key.
  grouped_vars = {}
  if isinstance(var_list, (tuple, list)):
    for var in var_list:
      ckpt_name = get_variable_full_name(var)
      if ckpt_name not in grouped_vars:
        grouped_vars[ckpt_name] = []
      grouped_vars[ckpt_name].append(var)

  else:
    for ckpt_name, value in var_list.items():
      if isinstance(value, (tuple, list)):
        grouped_vars[ckpt_name] = value
      else:
        grouped_vars[ckpt_name] = [value]

  # Read each checkpoint entry. Create a placeholder variable and
  # add the (possibly sliced) data from the checkpoint to the feed_dict.
  reader = pywrap_tensorflow.NewCheckpointReader(model_path)
  feed_dict = {}
  assign_ops = []
  for ckpt_name in grouped_vars:
    if not reader.has_tensor(ckpt_name):
      log_str = 'Checkpoint is missing variable [%s]' % ckpt_name
      if ignore_missing_vars:
        logging.warning(log_str)
        continue
      else:
        raise ValueError(log_str)
    ckpt_value = reader.get_tensor(ckpt_name)

    for var in grouped_vars[ckpt_name]:
      placeholder_tensor = array_ops.placeholder(
          dtype=var.dtype.base_dtype,
          shape=var.get_shape(),
          name='placeholder/' + var.op.name)
      assign_ops.append(var.assign(placeholder_tensor))

      if not var._save_slice_info:
        if var.get_shape() != ckpt_value.shape:
          raise ValueError(
              'Total size of new array must be unchanged for %s '
              'lh_shape: [%s], rh_shape: [%s]'
              % (ckpt_name, str(ckpt_value.shape), str(var.get_shape())))

        feed_dict[placeholder_tensor] = ckpt_value.reshape(ckpt_value.shape)
      else:
        slice_dims = zip(var._save_slice_info.var_offset,
                         var._save_slice_info.var_shape)
        slice_dims = [(start, start + size) for (start, size) in slice_dims]
        slice_dims = [slice(*x) for x in slice_dims]
        # Index with a tuple of slices; NumPy deprecated indexing with a list.
        slice_value = ckpt_value[tuple(slice_dims)]
        slice_value = slice_value.reshape(var._save_slice_info.var_shape)
        feed_dict[placeholder_tensor] = slice_value

  assign_op = control_flow_ops.group(*assign_ops)
  return assign_op, feed_dict
# pylint: enable=protected-access

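# Example (hedged sketch, not part of the original module): remapping
# checkpoint names to new variables. Checkpoint names, variables, and the
# path below are illustrative; `tf` is the public TensorFlow 1.x API.
#
#   var_list = {'old_scope/weights': new_weights,
#               'old_scope/biases': new_biases}
#   init_op, init_feed_dict = assign_from_checkpoint(
#       '/tmp/model.ckpt', var_list)
#   with tf.Session() as sess:
#     sess.run(init_op, init_feed_dict)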

def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False,
                              reshape_variables=False):
  """Returns a function that assigns specific variables from a checkpoint.

  If `ignore_missing_vars` is True and no variables are found in the
  checkpoint, it returns None.

  Args:
    model_path: The full path to the model checkpoint. To get the latest
      checkpoint use `model_path = tf.train.latest_checkpoint(checkpoint_dir)`.
    var_list: A list of `Variable` objects or a dictionary mapping names in the
      checkpoint to the corresponding variables to initialize. If empty or
      `None`, a `ValueError` is raised.
    ignore_missing_vars: Boolean, if True it would ignore variables missing in
      the checkpoint with a warning instead of failing.
    reshape_variables: Boolean, if True it would automatically reshape
      variables which are of different shape than the ones stored in the
      checkpoint but which have the same number of elements.

  Returns:
    A function that takes a single argument, a `tf.Session`, that applies the
    assignment operation. If no matching variables were found in the checkpoint
    then `None` is returned.

  Raises:
    ValueError: If var_list is empty.
  """
  if not var_list:
    raise ValueError('var_list cannot be empty')
  if ignore_missing_vars:
    reader = pywrap_tensorflow.NewCheckpointReader(model_path)
    if isinstance(var_list, dict):
      var_dict = var_list
    else:
      var_dict = {var.op.name: var for var in var_list}
    available_vars = {}
    for var in var_dict:
      if reader.has_tensor(var):
        available_vars[var] = var_dict[var]
      else:
        logging.warning(
            'Variable %s missing in checkpoint %s', var, model_path)
    var_list = available_vars
  if var_list:
    saver = tf_saver.Saver(var_list, reshape=reshape_variables,
                           write_version=saver_pb2.SaverDef.V1)

    def callback(session):
      saver.restore(session, model_path)

    return callback
  else:
    logging.warning('No Variables to restore')
    return None

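# Example (hedged sketch, not part of the original module): a typical
# fine-tuning setup that restores only the selected variables. Scope names
# and the checkpoint path are illustrative; `tf` is the public TensorFlow
# 1.x API.
#
#   variables_to_restore = get_variables_to_restore(exclude=['logits'])
#   init_fn = assign_from_checkpoint_fn('/tmp/model.ckpt',
#                                       variables_to_restore,
#                                       ignore_missing_vars=True)
#   with tf.Session() as sess:
#     init_fn(sess)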

class VariableDeviceChooser(object):
  """Device chooser for variables.

  When using a parameter server, it assigns variables to the parameter-server
  tasks in a round-robin fashion. When not using a parameter server, it allows
  GPU or CPU placement.
  """

  def __init__(self,
               num_tasks=0,
               job_name='ps',
               device_type='CPU',
               device_index=0):
    """Initialize VariableDeviceChooser.

    Usage:
      To use with 2 parameter servers:
        VariableDeviceChooser(2)

      To use without parameter servers:
        VariableDeviceChooser()
        VariableDeviceChooser(device_type='GPU')  # For GPU placement

    Args:
      num_tasks: number of tasks.
      job_name: String, a name for the parameter server job.
      device_type: Optional device type string (e.g. "CPU" or "GPU").
      device_index: int.  Optional device index.  If left
        unspecified, device represents 'any' device_index.
    """
    self._job_name = job_name
    self._device_type = device_type
    self._device_index = device_index
    self._num_tasks = num_tasks
    self._next_task_id = 0

  def __call__(self, op):
    device_spec = tf_device.DeviceSpec(device_type=self._device_type,
                                       device_index=self._device_index)
    if self._num_tasks > 0:
      task_id = self._next_task_id
      self._next_task_id = (self._next_task_id + 1) % self._num_tasks
      device_spec.job = self._job_name
      device_spec.task = task_id
    return device_spec.to_string()

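# Example (hedged sketch, not part of the original module): spreading
# variables across two parameter-server tasks by passing the chooser as the
# `device` argument of `variable`/`model_variable` (which accept a device
# function), here via `arg_scope` as in the earlier sketch. Names are
# illustrative.
#
#   chooser = VariableDeviceChooser(num_tasks=2)
#   with arg_scope([variable, model_variable], device=chooser):
#     w = model_variable('weights', shape=[10, 10])
#     b = model_variable('biases', shape=[10])
#   # Each variable's ops are placed on a '/job:ps/task:N/device:CPU:0'
#   # device chosen in round-robin order.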

def filter_variables(var_list, include_patterns=None, exclude_patterns=None,
                     reg_search=True):
  """Filter a list of variables using regular expressions.

  First includes variables according to the list of include_patterns.
  Afterwards, eliminates variables according to the list of exclude_patterns.

  For example, one can obtain a list of variables with the weights of all
  convolutional layers (depending on the network definition) by:

  ```python
  variables = tf.contrib.framework.get_model_variables()
  conv_weight_variables = tf.contrib.framework.filter_variables(
      variables,
      include_patterns=['Conv'],
      exclude_patterns=['biases', 'Logits'])
  ```

  Args:
    var_list: list of variables.
    include_patterns: list of regular expressions to include. Defaults to
      None, which means every variable passes the include step. A variable is
      included if it matches any of the include_patterns.
    exclude_patterns: list of regular expressions to exclude. Defaults to
      None, which means no variables are excluded. A variable is excluded if
      it matches any of the exclude_patterns.
    reg_search: boolean. If True (default), performs re.search to find matches
      (i.e. pattern can match any substring of the variable name). If False,
      performs re.match (i.e. regexp should match from the beginning of the
      variable name).

  Returns:
    filtered list of variables.
  """
  if reg_search:
    reg_exp_func = re.search
  else:
    reg_exp_func = re.match

  # First include variables.
  if include_patterns is None:
    included_variables = list(var_list)
  else:
    included_variables = []
    for var in var_list:
      if any(reg_exp_func(ptrn, var.name) for ptrn in include_patterns):
        included_variables.append(var)

  # Afterwards, exclude variables.
  if exclude_patterns is None:
    filtered_variables = included_variables
  else:
    filtered_variables = []
    for var in included_variables:
      if not any(reg_exp_func(ptrn, var.name) for ptrn in exclude_patterns):
        filtered_variables.append(var)

  return filtered_variables