# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Variable functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import re

from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope
from tensorflow.contrib.framework.python.ops import gen_variable_ops
from tensorflow.contrib.util import loader
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import training_util
from tensorflow.python.util.deprecation import deprecated


__all__ = ['add_model_variable',
           'assert_global_step',
           'assert_or_get_global_step',
           'assign_from_checkpoint',
           'assign_from_checkpoint_fn',
           'assign_from_values',
           'assign_from_values_fn',
           'create_global_step',
           'filter_variables',
           'get_global_step',
           'get_or_create_global_step',
           'get_local_variables',
           'get_model_variables',
           'get_trainable_variables',
           'get_unique_variable',
           'get_variables_by_name',
           'get_variables_by_suffix',
           'get_variable_full_name',
           'get_variables_to_restore',
           'get_variables',
           'global_variable',
           'local_variable',
           'model_variable',
           'variable',
           'VariableDeviceChooser',
           'zero_initializer']


def zero_initializer(ref, use_locking=True, name="zero_initializer"):
  """Initializes `ref` with all zeros.

  The `ref` tensor must be uninitialized; if it is already initialized, a
  `ValueError` is raised. This op is intended to save memory during
  initialization.

  Args:
    ref: the tensor to be zero-initialized.
    use_locking: currently unused by the underlying ops; retained for
      signature compatibility.
    name: optional name for this operation.

  Returns:
    The `ref` tensor, initialized with zeros.

  Raises:
    ValueError: If the `ref` tensor is already initialized.
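
  Example (an illustrative sketch; the variable name and shape are made up,
  and a graph-mode `tf.Session` is assumed):

  ```python
  var = tf.Variable(tf.ones([3, 3]), name='my_var')
  init_op = tf.contrib.framework.zero_initializer(var)
  with tf.Session() as sess:
    sess.run(init_op)  # `var` now holds zeros; its initializer was never run.
  ```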
  """
  loader.load_op_library(
      resource_loader.get_path_to_datafile("_variable_ops.so"))
  if resource_variable_ops.is_resource_variable(ref):
    return gen_variable_ops.zero_var_initializer(
        ref.handle, shape=ref.shape, dtype=ref.dtype, name=name)
  else:
    return gen_variable_ops.zero_initializer(ref, name=name)


@deprecated(None, "Please switch to tf.train.assert_global_step")
def assert_global_step(global_step_tensor):
  training_util.assert_global_step(global_step_tensor)


def assert_or_get_global_step(graph=None, global_step_tensor=None):
  """Verifies that a global step tensor is valid or gets one if None is given.

  If `global_step_tensor` is not None, check that it is a valid global step
  tensor (using `assert_global_step`). Otherwise find a global step tensor
  using `get_global_step` and return it.

  Args:
    graph: The graph to find the global step tensor for.
    global_step_tensor: The tensor to check for suitability as a global step.
      If None is given (the default), find a global step tensor.

  Returns:
    A tensor suitable as a global step, or `None` if none was provided and
    none was found.
  """
  if global_step_tensor is None:
    # Get the global step tensor the same way the supervisor would.
    global_step_tensor = get_global_step(graph)
  else:
    assert_global_step(global_step_tensor)
  return global_step_tensor


@deprecated(None, "Please switch to tf.train.get_global_step")
def get_global_step(graph=None):
  return training_util.get_global_step(graph)


@deprecated(None, "Please switch to tf.train.create_global_step")
def create_global_step(graph=None):
  """Creates a global step tensor in the graph.

  This API is deprecated. Use the core framework training version instead.

  Args:
    graph: The graph in which to create the global step tensor. If missing,
      use the default graph.

  Returns:
    The global step tensor.

  Raises:
    ValueError: if the global step tensor is already defined.
  """
  return training_util.create_global_step(graph)


@deprecated(None, "Please switch to tf.train.get_or_create_global_step")
def get_or_create_global_step(graph=None):
  """Returns the global step tensor, creating it if necessary.

  Args:
    graph: The graph in which to create the global step tensor. If missing,
      use the default graph.

  Returns:
    The global step tensor.
  """
  return training_util.get_or_create_global_step(graph)


def local_variable(initial_value,
                   validate_shape=True,
                   name=None,
                   use_resource=None):
  """Create a variable with a value and add it to `GraphKeys.LOCAL_VARIABLES`.

  Args:
    initial_value: See variables.Variable.__init__.
    validate_shape: See variables.Variable.__init__.
    name: See variables.Variable.__init__.
    use_resource: If `True` use a ResourceVariable instead of a Variable.

  Returns:
    New variable.
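
  Example (an illustrative sketch; `global_variable` is the analogue that
  targets `GraphKeys.GLOBAL_VARIABLES`):

  ```python
  counter = tf.contrib.framework.local_variable(0, name='counter')
  assert counter in tf.local_variables()
  ```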
  """
  return variable_scope.variable(
      initial_value, trainable=False,
      collections=[ops.GraphKeys.LOCAL_VARIABLES],
      validate_shape=validate_shape,
      use_resource=use_resource,
      name=name)


def global_variable(initial_value,
                    validate_shape=True,
                    name=None,
                    use_resource=None):
  """Create a variable with a value and add it to `GraphKeys.GLOBAL_VARIABLES`.

  Args:
    initial_value: See variables.Variable.__init__.
    validate_shape: See variables.Variable.__init__.
    name: See variables.Variable.__init__.
    use_resource: If `True` use a ResourceVariable instead of a Variable.

  Returns:
    New variable.
  """
  return variable_scope.variable(
      initial_value, trainable=False,
      collections=[ops.GraphKeys.GLOBAL_VARIABLES],
      validate_shape=validate_shape,
      use_resource=use_resource,
      name=name)


@contrib_add_arg_scope
def variable(name,
             shape=None,
             dtype=None,
             initializer=None,
             regularizer=None,
             trainable=True,
             collections=None,
             caching_device=None,
             device=None,
             partitioner=None,
             custom_getter=None,
             use_resource=None,
             synchronization=variables.VariableSynchronization.AUTO,
             aggregation=variables.VariableAggregation.NONE):
  """Gets an existing variable with these parameters or creates a new one.

  Args:
    name: the name of the new or existing variable.
    shape: shape of the new or existing variable.
    dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
    initializer: initializer for the variable if one is created.
    regularizer: a (Tensor -> Tensor or None) function; the result of applying
      it on a newly created variable will be added to the collection
      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    collections: A list of collection names to which the Variable will be
      added. If `None`, it defaults to `tf.GraphKeys.GLOBAL_VARIABLES`.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device.
    device: Optional device to place the variable. It can be a string or a
      function that is called to get the device for the variable.
    partitioner: Optional callable that accepts a fully defined `TensorShape`
      and dtype of the `Variable` to be created, and returns a list of
      partitions for each axis (currently only one axis can be partitioned).
    custom_getter: Callable that allows overwriting the internal get_variable
      method and has to have the same signature.
    use_resource: If `True` use a ResourceVariable instead of a Variable.
    synchronization: Indicates when a distributed variable will be
      synchronized. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses when to
      synchronize. If `synchronization` is set to `ON_READ`, `trainable` must
      not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.

  Returns:
    The created or existing variable.
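
  Example (an illustrative sketch; the name, shape, initializer and device
  are arbitrary choices, not requirements):

  ```python
  weights = tf.contrib.framework.variable(
      'weights',
      shape=[784, 10],
      initializer=tf.truncated_normal_initializer(stddev=0.01),
      device='/CPU:0')
  ```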
  """
  collections = list(collections if collections is not None
                     else [ops.GraphKeys.GLOBAL_VARIABLES])

  # Remove duplicates.
  collections = list(set(collections))
  getter = variable_scope.get_variable
  if custom_getter is not None:
    getter = functools.partial(custom_getter,
                               reuse=variable_scope.get_variable_scope().reuse)
  with ops.device(device or ''):
    return getter(
        name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        regularizer=regularizer,
        trainable=trainable,
        collections=collections,
        caching_device=caching_device,
        partitioner=partitioner,
        use_resource=use_resource,
        synchronization=synchronization,
        aggregation=aggregation)


@contrib_add_arg_scope
def model_variable(name,
                   shape=None,
                   dtype=dtypes.float32,
                   initializer=None,
                   regularizer=None,
                   trainable=True,
                   collections=None,
                   caching_device=None,
                   device=None,
                   partitioner=None,
                   custom_getter=None,
                   use_resource=None,
                   synchronization=variables.VariableSynchronization.AUTO,
                   aggregation=variables.VariableAggregation.NONE):
  """Gets an existing model variable with these parameters or creates a new one.

  Args:
    name: the name of the new or existing variable.
    shape: shape of the new or existing variable.
    dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
    initializer: initializer for the variable if one is created.
    regularizer: a (Tensor -> Tensor or None) function; the result of applying
      it on a newly created variable will be added to the collection
      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    collections: A list of collection names to which the Variable will be
      added. Note that the variable is always also added to the
      `GraphKeys.GLOBAL_VARIABLES` and `GraphKeys.MODEL_VARIABLES` collections.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device.
    device: Optional device to place the variable. It can be a string or a
      function that is called to get the device for the variable.
    partitioner: Optional callable that accepts a fully defined `TensorShape`
      and dtype of the `Variable` to be created, and returns a list of
      partitions for each axis (currently only one axis can be partitioned).
    custom_getter: Callable that allows overwriting the internal get_variable
      method and has to have the same signature.
    use_resource: If `True` use a ResourceVariable instead of a Variable.
    synchronization: Indicates when a distributed variable will be
      synchronized. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses when to
      synchronize. If `synchronization` is set to `ON_READ`, `trainable` must
      not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.

  Returns:
    The created or existing variable.
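
  Example (an illustrative sketch; note that the variable lands in both the
  global- and model-variable collections):

  ```python
  weights = tf.contrib.framework.model_variable(
      'weights', shape=[3, 3, 64, 64],
      initializer=tf.truncated_normal_initializer(stddev=0.01))
  assert weights in tf.contrib.framework.get_model_variables()
  ```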
  """
  collections = list(collections or [])
  collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES]
  var = variable(
      name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      regularizer=regularizer,
      trainable=trainable,
      collections=collections,
      caching_device=caching_device,
      device=device,
      partitioner=partitioner,
      custom_getter=custom_getter,
      use_resource=use_resource,
      synchronization=synchronization,
      aggregation=aggregation)
  return var


def add_model_variable(var):
  """Adds a variable to the `GraphKeys.MODEL_VARIABLES` collection.

  Args:
    var: a variable.
  """
  if var not in ops.get_collection(ops.GraphKeys.MODEL_VARIABLES):
    ops.add_to_collection(ops.GraphKeys.MODEL_VARIABLES, var)


def get_variables(scope=None, suffix=None,
                  collection=ops.GraphKeys.GLOBAL_VARIABLES):
  """Gets the list of variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return. Can be a
      variable scope or a string.
    suffix: an optional suffix for filtering the variables to return.
    collection: the collection to search in. Defaults to
      `GraphKeys.GLOBAL_VARIABLES`.

  Returns:
    a list of variables in the collection with scope and suffix.
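
  Example (an illustrative sketch; the scope and variable names are made up):

  ```python
  with tf.variable_scope('conv1'):
    w = tf.contrib.framework.variable('weights', shape=[3, 3, 1, 16])
    b = tf.contrib.framework.variable('biases', shape=[16])
  # Filter by scope, by suffix, or both.
  conv1_vars = tf.contrib.framework.get_variables('conv1')
  bias_vars = tf.contrib.framework.get_variables(suffix='biases')
  ```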
  """
  if isinstance(scope, variable_scope.VariableScope):
    scope = scope.name
  if suffix is not None:
    if ':' not in suffix:
      suffix += ':'
    scope = (scope or '') + '.*' + suffix
  return ops.get_collection(collection, scope)


def get_model_variables(scope=None, suffix=None):
  """Gets the list of model variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.

  Returns:
    a list of variables in the collection with scope and suffix.
  """
  return get_variables(scope, suffix, ops.GraphKeys.MODEL_VARIABLES)


def get_local_variables(scope=None, suffix=None):
  """Gets the list of local variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.

  Returns:
    a list of variables in the collection with scope and suffix.
  """
  return get_variables(scope, suffix, ops.GraphKeys.LOCAL_VARIABLES)


def get_trainable_variables(scope=None, suffix=None):
  """Gets the list of trainable variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.

  Returns:
    a list of variables in the trainable collection with scope and suffix.
  """
  return get_variables(scope, suffix, ops.GraphKeys.TRAINABLE_VARIABLES)


def get_variables_to_restore(include=None, exclude=None):
  """Gets the list of the variables to restore.

  Args:
    include: an optional list/tuple of scope strings for filtering which
      variables from the VARIABLES collection to include. If `None`, all
      variables are included.
    exclude: an optional list/tuple of scope strings for filtering which
      variables from the VARIABLES collection to exclude. If `None`, no
      variables are excluded.

  Returns:
    a list of variables to restore.

  Raises:
    TypeError: include or exclude is provided but is not a list or a tuple.
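
  Example (an illustrative sketch; the scope name is made up):

  ```python
  # Restore everything except variables under the 'logits' scope.
  variables_to_restore = tf.contrib.framework.get_variables_to_restore(
      exclude=['logits'])
  restorer = tf.train.Saver(variables_to_restore)
  ```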
  """
  if include is None:
    # Include all variables.
    vars_to_include = get_variables()
  else:
    if not isinstance(include, (list, tuple)):
      raise TypeError('include is provided but is not a list or a tuple.')
    vars_to_include = []
    for scope in include:
      vars_to_include += get_variables(scope)
  vars_to_exclude = set()
  if exclude is not None:
    if not isinstance(exclude, (list, tuple)):
      raise TypeError('exclude is provided but is not a list or a tuple.')
    for scope in exclude:
      vars_to_exclude |= set(get_variables(scope))
  # Exclude the variables in vars_to_exclude.
  return [v for v in vars_to_include if v not in vars_to_exclude]


def get_variables_by_suffix(suffix, scope=None):
  """Gets the list of variables that end with the given suffix.

  Args:
    suffix: suffix for filtering the variables to return.
    scope: an optional scope for filtering the variables to return.

  Returns:
    a copied list of variables with the given suffix and scope.
  """
  return get_variables(scope=scope, suffix=suffix)


def get_variables_by_name(given_name, scope=None):
  """Gets the list of variables that were given that name.

  Args:
    given_name: name given to the variable without any scope.
    scope: an optional scope for filtering the variables to return.

  Returns:
    a copied list of variables with the given name and scope.
  """
  suffix = '/' + given_name + ':|^' + given_name + ':'
  return get_variables(scope=scope, suffix=suffix)


def get_unique_variable(var_op_name):
  """Gets the variable uniquely identified by that var_op_name.

  Args:
    var_op_name: the full name of the variable op, including the scope.

  Returns:
    a tensorflow variable.

  Raises:
    ValueError: if no variable uniquely identified by the name exists.
  """
  candidates = get_variables(scope=var_op_name)
  if not candidates:
    raise ValueError('Couldn\'t find variable %s' % var_op_name)

  for candidate in candidates:
    if candidate.op.name == var_op_name:
      return candidate
  raise ValueError('Variable %s does not uniquely identify a variable' %
                   var_op_name)


def assign_from_values(var_names_to_values):
  """Creates an assignment operation from a given mapping.

  This function provides a mechanism for performing assignment of variables
  to values in a way that does not fill the graph with large constant values.

  Args:
    var_names_to_values: A map from variable names to values.

  Returns:
    assign_op: An `Operation` that assigns each of the given variables to the
      requested values.
    feed_dict: The feed dictionary to use when evaluating `assign_op`.

  Raises:
    ValueError: if any of the given variable names were not found.
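
  Example (an illustrative sketch; assumes `numpy` imported as `np`, and the
  variable names and shapes are made up):

  ```python
  var_names_to_values = {'weights': np.zeros([784, 10]),
                         'biases': np.zeros([10])}
  assign_op, feed_dict = tf.contrib.framework.assign_from_values(
      var_names_to_values)
  with tf.Session() as sess:
    sess.run(assign_op, feed_dict)
  ```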
  """
  feed_dict = {}
  assign_ops = []

  for var_name in var_names_to_values:
    var_value = var_names_to_values[var_name]
    var = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, var_name)
    if not var:
      raise ValueError('Variable %s wasn\'t found' % var_name)
    elif len(var) > 1:
      # tf.get_collection is just a filter on the prefix: find the exact match:
      found = False
      for v in var:
        if v.op.name == var_name:
          var = v
          found = True
          break

      if not found:
        raise ValueError('Variable %s doesn\'t uniquely identify a variable' %
                         var_name)
    else:
      var = var[0]

    # TODO(nsilberman): ensure placeholder and assign are on the same device.
    # Assign a placeholder to the value that will be filled later.
    placeholder_name = 'placeholder/' + var.op.name
    placeholder_value = array_ops.placeholder(
        dtype=var.dtype.base_dtype,
        shape=var.get_shape(),
        name=placeholder_name)
    assign_ops.append(var.assign(placeholder_value))

    feed_dict[placeholder_value] = var_value.reshape(var.get_shape())

  assign_op = control_flow_ops.group(*assign_ops)
  return assign_op, feed_dict


def assign_from_values_fn(var_names_to_values):
  """Returns a function that assigns specific variables from the given values.

  This function provides a mechanism for performing assignment of variables
  to values in a way that does not fill the graph with large constant values.

  Args:
    var_names_to_values: A map from variable names to values.

  Returns:
    A function that takes a single argument, a `tf.Session`, that applies the
    assignment operation.

  Raises:
    ValueError: if any of the given variable names were not found.
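
  Example (an illustrative sketch; the returned callback is typically used as
  an initialization function, e.g. passed to `tf.train.Supervisor`):

  ```python
  init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
  with tf.Session() as sess:
    init_fn(sess)
  ```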
  """
  assign_op, feed_dict = assign_from_values(var_names_to_values)
  def callback(session):
    return session.run(assign_op, feed_dict)
  return callback


# pylint: disable=protected-access
# Currently variable_scope doesn't provide very good APIs to access
# all variables under scope and retrieve and check existing scopes.
def get_variable_full_name(var):
  """Returns the full name of a variable.

  For normal Variables, this is the same as `var.op.name`. For sliced or
  PartitionedVariables, this name is the same for all the slices/partitions.
  In both cases, this is normally the name used in a checkpoint file.

  Args:
    var: A `Variable` object.

  Returns:
    A string that is the full name.
  """
  if var._save_slice_info:
    return var._save_slice_info.full_name
  else:
    return var.op.name


# TODO(nsilberman): add flag to load exponential moving averages instead
#
# TODO(sguada): Update docs in slim/g3doc/index.md to describe
# the new feature where the var_list dictionary can have values that
# are each a list of Variables.
def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False):
  """Creates an operation to assign specific variables from a checkpoint.

  Args:
    model_path: The full path to the model checkpoint. To get the latest
      checkpoint use `model_path = tf.train.latest_checkpoint(checkpoint_dir)`.
    var_list: A list of (possibly partitioned) `Variable` objects or a
      dictionary mapping names in the checkpoint to the corresponding
      variables or list of variables to initialize from that checkpoint value.
      For partitioned Variables, the name in the checkpoint must be the full
      variable, not the name of the partitioned variable, e.g. "my_var" rather
      than "my_var/part_4". If empty, returns `no_op(), {}`.
    ignore_missing_vars: Boolean, if True ignore variables missing in the
      checkpoint with a warning instead of failing.

  Returns:
    the restore_op and the feed_dict that need to be run to restore var_list.

  Raises:
    ValueError: If `ignore_missing_vars` is False and the checkpoint specified
      at `model_path` is missing one of the variables in `var_list`.
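
  Example (an illustrative sketch; the checkpoint directory and scope names
  are made up):

  ```python
  model_path = tf.train.latest_checkpoint('/tmp/model_dir')
  vars_to_restore = tf.contrib.framework.get_variables_to_restore(
      include=['conv1', 'conv2'])
  assign_op, feed_dict = tf.contrib.framework.assign_from_checkpoint(
      model_path, vars_to_restore)
  with tf.Session() as sess:
    sess.run(assign_op, feed_dict)
  ```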
  """
  # Normalize var_list into a dictionary mapping names in the
  # checkpoint to the list of variables to initialize from that
  # checkpoint variable. Sliced (including partitioned) variables will
  # end up under the same key.
  grouped_vars = {}
  if isinstance(var_list, (tuple, list)):
    for var in var_list:
      ckpt_name = get_variable_full_name(var)
      if ckpt_name not in grouped_vars:
        grouped_vars[ckpt_name] = []
      grouped_vars[ckpt_name].append(var)
  else:
    for ckpt_name, value in var_list.items():
      if isinstance(value, (tuple, list)):
        grouped_vars[ckpt_name] = value
      else:
        grouped_vars[ckpt_name] = [value]

  # Read each checkpoint entry. Create a placeholder variable and
  # add the (possibly sliced) data from the checkpoint to the feed_dict.
  reader = pywrap_tensorflow.NewCheckpointReader(model_path)
  feed_dict = {}
  assign_ops = []
  for ckpt_name in grouped_vars:
    if not reader.has_tensor(ckpt_name):
      log_str = 'Checkpoint is missing variable [%s]' % ckpt_name
      if ignore_missing_vars:
        logging.warning(log_str)
        continue
      else:
        raise ValueError(log_str)
    ckpt_value = reader.get_tensor(ckpt_name)

    for var in grouped_vars[ckpt_name]:
      placeholder_tensor = array_ops.placeholder(
          dtype=var.dtype.base_dtype,
          shape=var.get_shape(),
          name='placeholder/' + var.op.name)
      assign_ops.append(var.assign(placeholder_tensor))

      if not var._save_slice_info:
        if var.get_shape() != ckpt_value.shape:
          raise ValueError(
              'Shape mismatch for %s: checkpoint shape is [%s], variable '
              'shape is [%s]'
              % (ckpt_name, str(ckpt_value.shape), str(var.get_shape())))

        feed_dict[placeholder_tensor] = ckpt_value.reshape(ckpt_value.shape)
      else:
        slice_dims = zip(var._save_slice_info.var_offset,
                         var._save_slice_info.var_shape)
        slice_dims = [(start, start + size) for (start, size) in slice_dims]
        slice_dims = [slice(*x) for x in slice_dims]
        # Index with a tuple: indexing a numpy array with a list of slices
        # is deprecated.
        slice_value = ckpt_value[tuple(slice_dims)]
        slice_value = slice_value.reshape(var._save_slice_info.var_shape)
        feed_dict[placeholder_tensor] = slice_value

  assign_op = control_flow_ops.group(*assign_ops)
  return assign_op, feed_dict
# pylint: enable=protected-access


def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False,
                              reshape_variables=False):
  """Returns a function that assigns specific variables from a checkpoint.

  If ignore_missing_vars is True and no variables are found in the checkpoint
  it returns None.

  Args:
    model_path: The full path to the model checkpoint. To get the latest
      checkpoint use `model_path = tf.train.latest_checkpoint(checkpoint_dir)`.
    var_list: A list of `Variable` objects or a dictionary mapping names in
      the checkpoint to the corresponding variables to initialize. Must not be
      empty or `None` (a `ValueError` is raised otherwise).
    ignore_missing_vars: Boolean, if True it would ignore variables missing in
      the checkpoint with a warning instead of failing.
    reshape_variables: Boolean, if True it would automatically reshape
      variables which are of different shape than the ones stored in the
      checkpoint but which have the same number of elements.

  Returns:
    A function that takes a single argument, a `tf.Session`, that applies the
    assignment operation. If no matching variables were found in the
    checkpoint then `None` is returned.

  Raises:
    ValueError: If var_list is empty.
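
  Example (an illustrative sketch; the checkpoint path is made up):

  ```python
  init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
      '/tmp/model_dir/model.ckpt-10000',
      tf.contrib.framework.get_variables_to_restore(),
      ignore_missing_vars=True)
  with tf.Session() as sess:
    if init_fn is not None:
      init_fn(sess)
  ```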
  """
  if not var_list:
    raise ValueError('var_list cannot be empty')
  if ignore_missing_vars:
    reader = pywrap_tensorflow.NewCheckpointReader(model_path)
    if isinstance(var_list, dict):
      var_dict = var_list
    else:
      var_dict = {var.op.name: var for var in var_list}
    available_vars = {}
    for var in var_dict:
      if reader.has_tensor(var):
        available_vars[var] = var_dict[var]
      else:
        logging.warning(
            'Variable %s missing in checkpoint %s', var, model_path)
    var_list = available_vars
  if var_list:
    saver = tf_saver.Saver(var_list, reshape=reshape_variables,
                           write_version=saver_pb2.SaverDef.V1)
    def callback(session):
      saver.restore(session, model_path)
    return callback
  else:
    logging.warning('No Variables to restore')
    return None


class VariableDeviceChooser(object):
  """Device chooser for variables.

  When using a parameter server it will assign variables in a round-robin
  fashion. When not using a parameter server it allows GPU or CPU placement.
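
  Example (an illustrative sketch; pairs the chooser with `arg_scope` so that
  `variable` picks devices across two parameter-server tasks):

  ```python
  chooser = tf.contrib.framework.VariableDeviceChooser(num_tasks=2)
  with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
                                      device=chooser):
    # Variables created here are spread across /job:ps tasks 0 and 1.
    a = tf.contrib.framework.variable('a', shape=[])
    b = tf.contrib.framework.variable('b', shape=[])
  ```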
  """

  def __init__(self,
               num_tasks=0,
               job_name='ps',
               device_type='CPU',
               device_index=0,
               replica=None):
    """Initialize VariableDeviceChooser.

    Usage:
      To use with 2 parameter servers:
        VariableDeviceChooser(2)

      To use without parameter servers:
        VariableDeviceChooser()
        VariableDeviceChooser(device_type='GPU')  # For GPU placement

    Args:
      num_tasks: number of tasks.
      job_name: String, a name for the parameter server job.
      device_type: Optional device type string (e.g. "CPU" or "GPU").
      device_index: int.  Optional device index.  If left
        unspecified, device represents 'any' device_index.
      replica: Optional replica index for the device spec.
    """
    self._job_name = job_name
    self._device_type = device_type
    self._device_index = device_index
    self._replica = replica
    self._num_tasks = num_tasks
    self._next_task_id = 0

  def __call__(self, op):
    device_spec = tf_device.DeviceSpec(
        replica=self._replica,
        device_type=self._device_type,
        device_index=self._device_index)
    if self._num_tasks > 0:
      task_id = self._next_task_id
      self._next_task_id = (self._next_task_id + 1) % self._num_tasks
      device_spec.job = self._job_name
      device_spec.task = task_id
    return device_spec.to_string()


def filter_variables(var_list, include_patterns=None, exclude_patterns=None,
                     reg_search=True):
  """Filter a list of variables using regular expressions.

  First includes variables according to the list of include_patterns.
  Afterwards, eliminates variables according to the list of exclude_patterns.

  For example, one can obtain a list of variables with the weights of all
  convolutional layers (depending on the network definition) by:

  ```python
  variables = tf.contrib.framework.get_model_variables()
  conv_weight_variables = tf.contrib.framework.filter_variables(
      variables,
      include_patterns=['Conv'],
      exclude_patterns=['biases', 'Logits'])
  ```

  Args:
    var_list: list of variables.
    include_patterns: list of regular expressions to include. Defaults to
      None, which means all variables are included. A variable is included if
      it matches any of the include_patterns.
    exclude_patterns: list of regular expressions to exclude. Defaults to
      None, which means no variables are excluded. A variable is excluded if
      it matches any of the exclude_patterns.
    reg_search: boolean. If True (default), performs re.search to find matches
      (i.e. pattern can match any substring of the variable name). If False,
      performs re.match (i.e. regexp should match from the beginning of the
      variable name).

  Returns:
    filtered list of variables.
  """
  if reg_search:
    reg_exp_func = re.search
  else:
    reg_exp_func = re.match

  # First include variables.
  if include_patterns is None:
    included_variables = list(var_list)
  else:
    included_variables = []
    for var in var_list:
      if any(reg_exp_func(ptrn, var.name) for ptrn in include_patterns):
        included_variables.append(var)

  # Afterwards, exclude variables.
  if exclude_patterns is None:
    filtered_variables = included_variables
  else:
    filtered_variables = []
    for var in included_variables:
      if not any(reg_exp_func(ptrn, var.name) for ptrn in exclude_patterns):
        filtered_variables.append(var)

  return filtered_variables