# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to warm-start TF.Learn Estimators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import six

from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_ops
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.VocabInfo"])
class VocabInfo(
    collections.namedtuple("VocabInfo", [
        "new_vocab",
        "new_vocab_size",
        "num_oov_buckets",
        "old_vocab",
        "old_vocab_size",
        "backup_initializer",
        "axis",
    ])):
46  """Vocabulary information for warm-starting.
47
48  See `tf.estimator.WarmStartSettings` for examples of using
49  VocabInfo to warm-start.
50
51  Attributes:
52    new_vocab: [Required] A path to the new vocabulary file (used with the
53      model to be trained).
54    new_vocab_size: [Required] An integer indicating how many entries of the new
55      vocabulary will used in training.
56    num_oov_buckets: [Required] An integer indicating how many OOV buckets are
57      associated with the vocabulary.
58    old_vocab: [Required] A path to the old vocabulary file (used with the
59      checkpoint to be warm-started from).
60    old_vocab_size: [Optional] An integer indicating how many entries of the old
61      vocabulary were used in the creation of the checkpoint. If not provided,
62      the entire old vocabulary will be used.
63    backup_initializer: [Optional] A variable initializer used for variables
64      corresponding to new vocabulary entries and OOV. If not provided, these
65      entries will be zero-initialized.
66    axis: [Optional] Denotes what axis the vocabulary corresponds to.  The
67      default, 0, corresponds to the most common use case (embeddings or
68      linear weights for binary classification / regression).  An axis of 1
69      could be used for warm-starting output layers with class vocabularies.
70
71      For example:
72
73      embeddings_vocab_info = tf.VocabInfo(
74          new_vocab='embeddings_vocab',
75          new_vocab_size=100,
76          num_oov_buckets=1,
77          old_vocab='pretrained_embeddings_vocab',
78          old_vocab_size=10000,
79          backup_initializer=tf.truncated_normal_initializer(
80              mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
81          axis=0)
82
83      softmax_output_layer_kernel_vocab_info = tf.VocabInfo(
84          new_vocab='class_vocab',
85          new_vocab_size=5,
86          num_oov_buckets=0,  # No OOV for classes.
87          old_vocab='old_class_vocab',
88          old_vocab_size=8,
89          backup_initializer=tf.glorot_uniform_initializer(),
90          axis=1)
91
92      softmax_output_layer_bias_vocab_info = tf.VocabInfo(
93          new_vocab='class_vocab',
94          new_vocab_size=5,
95          num_oov_buckets=0,  # No OOV for classes.
96          old_vocab='old_class_vocab',
97          old_vocab_size=8,
98          backup_initializer=tf.zeros_initializer(),
99          axis=0)
100
101      Currently, only axis=0 and axis=1 are supported.
102  """

  def __new__(cls,
              new_vocab,
              new_vocab_size,
              num_oov_buckets,
              old_vocab,
              old_vocab_size=-1,
              backup_initializer=None,
              axis=0):
    if axis != 0 and axis != 1:
      raise ValueError("The only supported values for the axis argument are 0 "
                       "and 1.  Provided axis: {}".format(axis))

    return super(VocabInfo, cls).__new__(
        cls,
        new_vocab,
        new_vocab_size,
        num_oov_buckets,
        old_vocab,
        old_vocab_size,
        backup_initializer,
        axis,
    )
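

# Example (an illustrative sketch; the vocabulary paths are hypothetical):
# since `old_vocab_size`, `backup_initializer`, and `axis` have defaults, a
# minimal VocabInfo needs only the four required fields.  This uses the entire
# old vocabulary, zero-initializes new and OOV entries, and remaps along
# axis 0:
#
#   vocab_info = tf.train.VocabInfo(
#       new_vocab="new_vocab.txt",
#       new_vocab_size=100,
#       num_oov_buckets=1,
#       old_vocab="old_vocab.txt")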


def _infer_var_name(var):
  """Returns the name of `var`.

  Args:
    var: A list. The list can contain either of the following:
      (i) A single `Variable`
      (ii) A single `ResourceVariable`
      (iii) Multiple `Variable` objects which must be slices of the same larger
        variable.
      (iv) A single `PartitionedVariable`

  Returns:
    The name of `var`.
  """
  name_to_var_dict = saveable_object_util.op_list_to_dict(var)
  if len(name_to_var_dict) > 1:
    raise TypeError("`var` = %s passed as arg violates the constraints.  "
                    "name_to_var_dict = %s" % (var, name_to_var_dict))
  return list(name_to_var_dict.keys())[0]
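
# Example (an illustrative sketch): for a variable created as
# `v = tf.get_variable("linear/weights", shape=[10, 1])`,
# `_infer_var_name([v])` returns "linear/weights"; for a list containing the
# slices of one partitioned variable, it returns the name of the larger
# variable they belong to.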


def _get_var_info(var, prev_tensor_name=None):
  """Helper method for standardizing `Variable` handling and naming.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following: (i) `Variable` (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable. (iv) `PartitionedVariable`
    prev_tensor_name: Name of the tensor to look up in the provided
      `prev_ckpt`. If None, we look up the tensor with the same name as the
      given `var`.

  Returns:
    A tuple of the Tensor name and var.
  """
  if checkpoint_utils._is_variable(var):  # pylint: disable=protected-access
    current_var_name = _infer_var_name([var])
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):  # pylint: disable=protected-access
    current_var_name = _infer_var_name(var)
  elif isinstance(var, variables_lib.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    var = var._get_variable_list()  # pylint: disable=protected-access
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))
  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = current_var_name

  return prev_tensor_name, var
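
# Example (an illustrative sketch; names are hypothetical): for a variable
# named "dense/kernel" and `prev_tensor_name="old_dense/kernel"`,
# `_get_var_info` returns ("old_dense/kernel", var); with
# `prev_tensor_name=None` it returns ("dense/kernel", var), i.e. the
# checkpoint tensor is assumed to have the same name as `var`.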


# pylint: disable=protected-access
# Accesses protected members of tf.Variable to reset the variable's internal
# state.
def _warm_start_var_with_vocab(var,
                               current_vocab_path,
                               current_vocab_size,
                               prev_ckpt,
                               prev_vocab_path,
                               previous_vocab_size=-1,
                               current_oov_buckets=0,
                               prev_tensor_name=None,
                               initializer=None,
                               axis=0):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Use this method when the `var` is backed by a vocabulary. This method
  stitches the given `var` so that values corresponding to individual features
  in the vocabulary remain consistent regardless of changes in the order of
  the features between the old and new vocabularies.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    current_vocab_path: Path to the vocab file used for the given `var`.
    current_vocab_size: An `int` specifying the number of entries in the current
      vocab.
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have a tensor with name
      `prev_tensor_name` (if not None) or a tensor with the same name as the
      given `var`.
    prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`.
    previous_vocab_size: If provided, will constrain the previous vocab to the
      first `previous_vocab_size` entries.  -1 means use the entire previous
      vocab.
    current_oov_buckets: An `int` specifying the number of out-of-vocabulary
      buckets used for the given `var`.
    prev_tensor_name: Name of the tensor to look up in the provided
      `prev_ckpt`. If None, we look up the tensor with the same name as the
      given `var`.
    initializer: Variable initializer to be used for missing entries.  If None,
      missing entries will be zero-initialized.
    axis: Axis of the variable that the provided vocabulary corresponds to.

  Raises:
    ValueError: If required args are not provided.
  """
  if not (current_vocab_path and current_vocab_size and prev_ckpt and
          prev_vocab_path):
    raise ValueError("Invalid args: Must provide all of [current_vocab_path, "
                     "current_vocab_size, prev_ckpt, prev_vocab_path].")
  if checkpoint_utils._is_variable(var):
    var = [var]
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):
    pass  # `var` is already a list of `Variable` slices.
  elif isinstance(var, variables_lib.PartitionedVariable):
    var = var._get_variable_list()
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))

  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = _infer_var_name(var)

  # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases).
  total_v_first_axis = sum(v.get_shape().as_list()[0] for v in var)
  for v in var:
    v_shape = v.get_shape().as_list()
    slice_info = v._get_save_slice_info()
    partition_info = None
    if slice_info:
      partition_info = variable_scope._PartitionInfo(
          full_shape=slice_info.full_shape,
          var_offset=slice_info.var_offset)

    if axis == 0:
      new_row_vocab_size = current_vocab_size
      new_col_vocab_size = v_shape[1]
      old_row_vocab_size = previous_vocab_size
      old_row_vocab_file = prev_vocab_path
      new_row_vocab_file = current_vocab_path
      old_col_vocab_file = None
      new_col_vocab_file = None
      num_row_oov_buckets = current_oov_buckets
      num_col_oov_buckets = 0
    elif axis == 1:
      # Note that we must compute this value across all partitions, whereas
      # in the axis = 0 case, we can simply use v_shape[1] because we don't
      # allow partitioning across axis = 1.
      new_row_vocab_size = total_v_first_axis
      new_col_vocab_size = current_vocab_size
      old_row_vocab_size = -1
      old_row_vocab_file = None
      new_row_vocab_file = None
      old_col_vocab_file = prev_vocab_path
      new_col_vocab_file = current_vocab_path
      num_row_oov_buckets = 0
      num_col_oov_buckets = current_oov_buckets
    else:
      raise ValueError("The only supported values for the axis argument are 0 "
                       "and 1.  Provided axis: {}".format(axis))

    init = checkpoint_ops._load_and_remap_matrix_initializer(
        ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
        old_tensor_name=prev_tensor_name,
        new_row_vocab_size=new_row_vocab_size,
        new_col_vocab_size=new_col_vocab_size,
        old_row_vocab_size=old_row_vocab_size,
        old_row_vocab_file=old_row_vocab_file,
        new_row_vocab_file=new_row_vocab_file,
        old_col_vocab_file=old_col_vocab_file,
        new_col_vocab_file=new_col_vocab_file,
        num_row_oov_buckets=num_row_oov_buckets,
        num_col_oov_buckets=num_col_oov_buckets,
        initializer=initializer)
    new_init_val = ops.convert_to_tensor(
        init(shape=v_shape, partition_info=partition_info))
    v._initializer_op = state_ops.assign(v, new_init_val)
# pylint: enable=protected-access
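
# Example (an illustrative sketch; paths and sizes are hypothetical): remap an
# embedding variable whose row vocabulary grew from 10000 to 12000 entries,
# with one OOV bucket.  Rows for features present in both vocab files keep
# their checkpointed values regardless of row order; rows for new features and
# OOV are drawn from `initializer` (zeros if None):
#
#   _warm_start_var_with_vocab(
#       embedding_var,
#       current_vocab_path="new_vocab.txt",
#       current_vocab_size=12000,
#       prev_ckpt="/tmp/prev_model",
#       prev_vocab_path="old_vocab.txt",
#       previous_vocab_size=10000,
#       current_oov_buckets=1,
#       initializer=tf.truncated_normal_initializer(stddev=0.1),
#       axis=0)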


def _get_grouped_variables(vars_to_warm_start):
  """Collects and groups (possibly partitioned) variables into a dictionary.

  The variables can be provided explicitly through vars_to_warm_start, or they
  are retrieved from collections (see below).

  Args:
    vars_to_warm_start: One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.get_collection).  This expression will only consider
        variables in the TRAINABLE_VARIABLES collection.
      - A list of Variables to warm-start.
      - A list of strings, each a regex scope provided to tf.get_collection
        with GLOBAL_VARIABLES.
      - `None`, in which case only variables specified in
        `var_name_to_vocab_info` will be warm-started.
  Returns:
    A dictionary mapping variable names (strings) to lists of Variables.
  Raises:
    ValueError: If vars_to_warm_start is not a string, `None`, a list of
      `Variables`, or a list of strings.
  """
  if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None:
    # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
    # everything (in TRAINABLE_VARIABLES) here.
    list_of_vars = ops.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES,
        scope=vars_to_warm_start)
  elif isinstance(vars_to_warm_start, list):
    if all(isinstance(v, str) for v in vars_to_warm_start):
      list_of_vars = []
      for v in vars_to_warm_start:
        list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                           scope=v)
    elif all(checkpoint_utils._is_variable(v) for v in vars_to_warm_start):  # pylint: disable=protected-access
      list_of_vars = vars_to_warm_start
    else:
      raise ValueError("If `vars_to_warm_start` is a list, it must be all "
                       "`Variable` or all `str`.  Given types are {}".format(
                           [type(v) for v in vars_to_warm_start]))
  else:
    raise ValueError("`vars_to_warm_start` must be a `list` or `str`.  Given "
                     "type is {}".format(type(vars_to_warm_start)))
  # We have to deal with partitioned variables, since get_collection flattens
  # out the list.
  grouped_variables = {}
  for v in list_of_vars:
    if not isinstance(v, list):
      var_name = _infer_var_name([v])
    else:
      var_name = _infer_var_name(v)
    grouped_variables.setdefault(var_name, []).append(v)

  return grouped_variables
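
# Example (an illustrative sketch; names are hypothetical): if the graph has
# an unpartitioned variable "linear/bias" and a variable "linear/weights"
# partitioned into two slices, `_get_grouped_variables(".*")` returns
#   {"linear/bias": [bias_var],
#    "linear/weights": [weights_part_0, weights_part_1]},
# i.e. the slices of a partitioned variable are regrouped under one name.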


@tf_export(v1=["train.warm_start"])
def warm_start(ckpt_to_initialize_from,
               vars_to_warm_start=".*",
               var_name_to_vocab_info=None,
               var_name_to_prev_var_name=None):
  """Warm-starts a model using the given settings.

  If you are using a tf.estimator.Estimator, this will automatically be called
  during training.
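
  For example, to warm-start all trainable variables whose names match
  "dense" from a previously trained model (an illustrative sketch; the
  checkpoint path is hypothetical):

    tf.train.warm_start(
        ckpt_to_initialize_from="/tmp/prev_model",
        vars_to_warm_start="dense.*")

  To also remap a variable's vocabulary, pass a `tf.estimator.VocabInfo` for
  it in `var_name_to_vocab_info`, keyed by the full variable name.
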
  Args:
    ckpt_to_initialize_from: [Required] A string specifying the directory with
      checkpoint file(s) or path to checkpoint from which to warm-start the
      model parameters.
    vars_to_warm_start: [Optional] One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.get_collection).  This expression will only consider
        variables in the TRAINABLE_VARIABLES collection -- if you need to
        warm-start non-TRAINABLE vars (such as optimizer accumulators or batch
        norm statistics), please use the below option.
      - A list of Variables to warm-start.  If you do not have access to the
        `Variable` objects at the call site, please use the below option.
      - A list of strings, each a regex scope provided to tf.get_collection
        with GLOBAL_VARIABLES (please see tf.get_collection).  For backwards
        compatibility reasons, this is separate from the single-string argument
        type.
      - `None`, in which case only variables specified in
        `var_name_to_vocab_info` will be warm-started.

      Defaults to `'.*'`, which warm-starts all variables in the
      TRAINABLE_VARIABLES collection.  Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
      `tf.estimator.VocabInfo`. The variable names should be "full" variables,
      not the names of the partitions.  If not explicitly provided, the variable
      is assumed to have no (changes to) vocabulary.
    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
      name of the previously-trained variable in `ckpt_to_initialize_from`. If
      not explicitly provided, the name of the variable is assumed to be same
      between previous checkpoint and current model.  Note that this has no
      effect on the set of variables that is warm-started, and only controls
      name mapping (use `vars_to_warm_start` for controlling what variables to
      warm-start).
  Raises:
    ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
      configuration for variable names that are not used.  This is to ensure
      a stronger check for variable configuration than relying on users to
      examine the logs.
  """
  if var_name_to_vocab_info is None:
    var_name_to_vocab_info = {}
  if var_name_to_prev_var_name is None:
    var_name_to_prev_var_name = {}
  logging.info("Warm-starting from: %s", ckpt_to_initialize_from)
  grouped_variables = _get_grouped_variables(vars_to_warm_start)

  # Keep track of which var_names in var_name_to_prev_var_name and
  # var_name_to_vocab_info have been used.  Err on the safer side by throwing an
  # exception if any are unused by the end of the loop.  It is easy to misname
  # a variable during this configuration, in which case without this check, we
  # would fail to warm-start silently.
  prev_var_name_used = set()
  vocab_info_used = set()

  # Group the vocabless vars into one call to init_from_checkpoint.
  vocabless_vars = {}
  for var_name, variable in six.iteritems(grouped_variables):
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      prev_var_name_used.add(var_name)
    vocab_info = var_name_to_vocab_info.get(var_name)
    if vocab_info:
      vocab_info_used.add(var_name)
      logging.info(
          "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
          " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
          " initializer: {}".format(
              var_name,
              vocab_info.new_vocab,
              vocab_info.new_vocab_size,
              vocab_info.old_vocab,
              (vocab_info.old_vocab_size if vocab_info.old_vocab_size > 0
               else "All"),
              vocab_info.num_oov_buckets,
              prev_var_name or "Unchanged",
              vocab_info.backup_initializer or "zero-initialized"))
      _warm_start_var_with_vocab(
          variable,
          current_vocab_path=vocab_info.new_vocab,
          current_vocab_size=vocab_info.new_vocab_size,
          prev_ckpt=ckpt_to_initialize_from,
          prev_vocab_path=vocab_info.old_vocab,
          previous_vocab_size=vocab_info.old_vocab_size,
          current_oov_buckets=vocab_info.num_oov_buckets,
          prev_tensor_name=prev_var_name,
          initializer=vocab_info.backup_initializer,
          axis=vocab_info.axis)
    else:
      # For the special value of vars_to_warm_start = None,
      # we only warm-start variables with explicitly specified vocabularies.
      if vars_to_warm_start:
        logging.info("Warm-starting variable: {}; prev_var_name: {}".format(
            var_name, prev_var_name or "Unchanged"))
        # Because we use a default empty list in grouped_variables, single
        # unpartitioned variables will be lists here, which we rectify in order
        # for init_from_checkpoint logic to work correctly.
        if len(variable) == 1:
          variable = variable[0]
        prev_tensor_name, var = _get_var_info(variable, prev_var_name)
        vocabless_vars[prev_tensor_name] = var

  checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from, vocabless_vars)
  prev_var_name_not_used = set(
      var_name_to_prev_var_name.keys()) - prev_var_name_used
  vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

  if prev_var_name_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_prev_var_name that were not used: "
        "{0}.  Perhaps you misspelled them?  Here is the list of viable "
        "variable names: {1}".format(prev_var_name_not_used,
                                     grouped_variables.keys()))
  if vocab_info_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_vocab_info that were not used: {0}. "
        " Perhaps you misspelled them?  Here is the list of viable variable "
        "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))