# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to warm-start TF.Learn Estimators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import six

from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_ops
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.VocabInfo"])
class VocabInfo(
    collections.namedtuple("VocabInfo", [
        "new_vocab",
        "new_vocab_size",
        "num_oov_buckets",
        "old_vocab",
        "old_vocab_size",
        "backup_initializer",
        "axis",
    ])):
  """Vocabulary information for warm-starting.

  See `tf.estimator.WarmStartSettings` for examples of using
  VocabInfo to warm-start.

  Attributes:
    new_vocab: [Required] A path to the new vocabulary file (used with the
      model to be trained).
    new_vocab_size: [Required] An integer indicating how many entries of the
      new vocabulary will be used in training.
    num_oov_buckets: [Required] An integer indicating how many OOV buckets are
      associated with the vocabulary.
    old_vocab: [Required] A path to the old vocabulary file (used with the
      checkpoint to be warm-started from).
    old_vocab_size: [Optional] An integer indicating how many entries of the
      old vocabulary were used in the creation of the checkpoint. If not
      provided, the entire old vocabulary will be used.
    backup_initializer: [Optional] A variable initializer used for variables
      corresponding to new vocabulary entries and OOV. If not provided, these
      entries will be zero-initialized.
    axis: [Optional] Denotes what axis the vocabulary corresponds to. The
      default, 0, corresponds to the most common use case (embeddings or
      linear weights for binary classification / regression). An axis of 1
      could be used for warm-starting output layers with class vocabularies.

  For example:

  embeddings_vocab_info = tf.VocabInfo(
      new_vocab='embeddings_vocab',
      new_vocab_size=100,
      num_oov_buckets=1,
      old_vocab='pretrained_embeddings_vocab',
      old_vocab_size=10000,
      backup_initializer=tf.truncated_normal_initializer(
          mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
      axis=0)

  softmax_output_layer_kernel_vocab_info = tf.VocabInfo(
      new_vocab='class_vocab',
      new_vocab_size=5,
      num_oov_buckets=0,  # No OOV for classes.
      old_vocab='old_class_vocab',
      old_vocab_size=8,
      backup_initializer=tf.glorot_uniform_initializer(),
      axis=1)

  softmax_output_layer_bias_vocab_info = tf.VocabInfo(
      new_vocab='class_vocab',
      new_vocab_size=5,
      num_oov_buckets=0,  # No OOV for classes.
      old_vocab='old_class_vocab',
      old_vocab_size=8,
      backup_initializer=tf.zeros_initializer(),
      axis=0)

  Currently, only axis=0 and axis=1 are supported.
  """

  def __new__(cls,
              new_vocab,
              new_vocab_size,
              num_oov_buckets,
              old_vocab,
              old_vocab_size=-1,
              backup_initializer=None,
              axis=0):
    if axis != 0 and axis != 1:
      raise ValueError("The only supported values for the axis argument are 0 "
                       "and 1. Provided axis: {}".format(axis))

    return super(VocabInfo, cls).__new__(
        cls,
        new_vocab,
        new_vocab_size,
        num_oov_buckets,
        old_vocab,
        old_vocab_size,
        backup_initializer,
        axis,
    )
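
# A minimal usage sketch (illustrative only; not executed by this module). A
# VocabInfo is typically keyed by the name of the variable whose vocabulary
# changed, either in `tf.estimator.WarmStartSettings` or in the
# `var_name_to_vocab_info` argument of `warm_start` below. The paths, sizes,
# and variable name here are hypothetical placeholders.
#
#   vocab_info = VocabInfo(
#       new_vocab="new_vocab.txt",
#       new_vocab_size=100,
#       num_oov_buckets=1,
#       old_vocab="old_vocab.txt")
#   settings = tf.estimator.WarmStartSettings(
#       ckpt_to_initialize_from="/tmp/prev_model_dir",
#       var_name_to_vocab_info={
#           "input_layer/emb/embedding_weights": vocab_info})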


def _infer_var_name(var):
  """Returns the name of the `var`.

  Args:
    var: A list. The list can contain either of the following:
      (i) A single `Variable`
      (ii) A single `ResourceVariable`
      (iii) Multiple `Variable` objects which must be slices of the same larger
        variable.
      (iv) A single `PartitionedVariable`

  Returns:
    Name of the `var`.
  """
  name_to_var_dict = saveable_object_util.op_list_to_dict(var)
  if len(name_to_var_dict) > 1:
    raise TypeError("`var` = %s passed as arg violates the constraints. "
                    "name_to_var_dict = %s" % (var, name_to_var_dict))
  return list(name_to_var_dict.keys())[0]


def _get_var_info(var, prev_tensor_name=None):
  """Helper method for standardizing Variable and naming.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following: (i) `Variable` (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable. (iv) `PartitionedVariable`
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.

  Returns:
    A tuple of the Tensor name and var.
  """
  if checkpoint_utils._is_variable(var):  # pylint: disable=protected-access
    current_var_name = _infer_var_name([var])
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):  # pylint: disable=protected-access
    current_var_name = _infer_var_name(var)
  elif isinstance(var, variables_lib.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    var = var._get_variable_list()  # pylint: disable=protected-access
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))
  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = current_var_name

  return prev_tensor_name, var
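
# A small sketch of the helper's contract (illustrative only; variable names
# are hypothetical):
#
#   w = tf.get_variable("dense/kernel", shape=[10, 4])
#   prev_tensor_name, var = _get_var_info(w)
#   # prev_tensor_name == "dense/kernel" (same-name lookup assumed); var is w.
#
#   prev_tensor_name, var = _get_var_info(w, prev_tensor_name="old/kernel")
#   # prev_tensor_name == "old/kernel": the tensor read from the checkpoint.
#
# For a `PartitionedVariable`, `var` comes back as the list of underlying
# slices so that downstream checkpoint logic can restore each slice.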


# pylint: disable=protected-access
# Accesses protected members of tf.Variable to reset the variable's internal
# state.
def _warm_start_var_with_vocab(var,
                               current_vocab_path,
                               current_vocab_size,
                               prev_ckpt,
                               prev_vocab_path,
                               previous_vocab_size=-1,
                               current_oov_buckets=0,
                               prev_tensor_name=None,
                               initializer=None,
                               axis=0):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Use this method when the `var` is backed by vocabulary. This method stitches
  the given `var` such that values corresponding to individual features in the
  vocabulary remain consistent irrespective of changing order of the features
  between old and new vocabularies.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    current_vocab_path: Path to the vocab file used for the given `var`.
    current_vocab_size: An `int` specifying the number of entries in the
      current vocab.
    prev_ckpt: A string specifying the directory with checkpoint file(s) or
      path to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`.
    previous_vocab_size: If provided, will constrain previous vocab to the
      first `previous_vocab_size` entries. -1 means use the entire previous
      vocab.
    current_oov_buckets: An `int` specifying the number of out-of-vocabulary
      buckets used for given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.
    initializer: Variable initializer to be used for missing entries. If None,
      missing entries will be zero-initialized.
    axis: Axis of the variable that the provided vocabulary corresponds to.

  Raises:
    ValueError: If required args are not provided.
  """
  if not (current_vocab_path and current_vocab_size and prev_ckpt and
          prev_vocab_path):
    raise ValueError("Invalid args: Must provide all of [current_vocab_path, "
                     "current_vocab_size, prev_ckpt, prev_vocab_path].")
  if checkpoint_utils._is_variable(var):
    var = [var]
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):
    var = var
  elif isinstance(var, variables_lib.PartitionedVariable):
    var = var._get_variable_list()
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))

  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = _infer_var_name(var)

  # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases).
  total_v_first_axis = sum(v.get_shape().as_list()[0] for v in var)
  for v in var:
    v_shape = v.get_shape().as_list()
    slice_info = v._get_save_slice_info()
    partition_info = None
    if slice_info:
      partition_info = variable_scope._PartitionInfo(
          full_shape=slice_info.full_shape,
          var_offset=slice_info.var_offset)

    if axis == 0:
      new_row_vocab_size = current_vocab_size
      new_col_vocab_size = v_shape[1]
      old_row_vocab_size = previous_vocab_size
      old_row_vocab_file = prev_vocab_path
      new_row_vocab_file = current_vocab_path
      old_col_vocab_file = None
      new_col_vocab_file = None
      num_row_oov_buckets = current_oov_buckets
      num_col_oov_buckets = 0
    elif axis == 1:
      # Note that we must compute this value across all partitions, whereas
      # in the axis = 0 case, we can simply use v_shape[1] because we don't
      # allow partitioning across axis = 1.
      new_row_vocab_size = total_v_first_axis
      new_col_vocab_size = current_vocab_size
      old_row_vocab_size = -1
      old_row_vocab_file = None
      new_row_vocab_file = None
      old_col_vocab_file = prev_vocab_path
      new_col_vocab_file = current_vocab_path
      num_row_oov_buckets = 0
      num_col_oov_buckets = current_oov_buckets
    else:
      raise ValueError("The only supported values for the axis argument are 0 "
                       "and 1. Provided axis: {}".format(axis))

    init = checkpoint_ops._load_and_remap_matrix_initializer(
        ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
        old_tensor_name=prev_tensor_name,
        new_row_vocab_size=new_row_vocab_size,
        new_col_vocab_size=new_col_vocab_size,
        old_row_vocab_size=old_row_vocab_size,
        old_row_vocab_file=old_row_vocab_file,
        new_row_vocab_file=new_row_vocab_file,
        old_col_vocab_file=old_col_vocab_file,
        new_col_vocab_file=new_col_vocab_file,
        num_row_oov_buckets=num_row_oov_buckets,
        num_col_oov_buckets=num_col_oov_buckets,
        initializer=initializer)
    new_init_val = ops.convert_to_tensor(
        init(shape=v_shape, partition_info=partition_info))
    v._initializer_op = state_ops.assign(v, new_init_val)
# pylint: enable=protected-access
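
# A concrete sketch of the remapping behavior (illustrative only; vocab
# entries are hypothetical). For axis=0 with:
#
#   old vocab file: ["knitting", "eminem"]
#   new vocab file: ["eminem", "knitting", "macrame"]
#
# the checkpoint row for "eminem" is copied to new row 0 and the row for
# "knitting" to new row 1, so each feature keeps its learned weights despite
# the reordering. The row for "macrame" (a brand-new entry) and any OOV-bucket
# rows are filled from `initializer`, or with zeros when `initializer` is
# None.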


def _get_grouped_variables(vars_to_warm_start):
  """Collects and groups (possibly partitioned) variables into a dictionary.

  The variables can be provided explicitly through vars_to_warm_start, or they
  are retrieved from collections (see below).

  Args:
    vars_to_warm_start: One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.get_collection). This expression will only consider
        variables in the TRAINABLE_VARIABLES collection.
      - A list of Variables to warm-start.
      - A list of strings, each representing a full variable name to
        warm-start.
      - `None`, in which case only variables specified in
        `var_name_to_vocab_info` will be warm-started.

  Returns:
    A dictionary mapping variable names (strings) to lists of Variables.

  Raises:
    ValueError: If vars_to_warm_start is not a string, `None`, a list of
      `Variables`, or a list of strings.
  """
  if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None:
    # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
    # everything (in TRAINABLE_VARIABLES) here.
    list_of_vars = ops.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES,
        scope=vars_to_warm_start)
  elif isinstance(vars_to_warm_start, list):
    if all(isinstance(v, str) for v in vars_to_warm_start):
      list_of_vars = []
      for v in vars_to_warm_start:
        list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                           scope=v)
    elif all(checkpoint_utils._is_variable(v) for v in vars_to_warm_start):  # pylint: disable=protected-access
      list_of_vars = vars_to_warm_start
    else:
      raise ValueError("If `vars_to_warm_start` is a list, it must be all "
                       "`Variable` or all `str`. Given types are {}".format(
                           [type(v) for v in vars_to_warm_start]))
  else:
    raise ValueError("`vars_to_warm_start` must be a `list` or `str`. Given "
                     "type is {}".format(type(vars_to_warm_start)))

  # We have to deal with partitioned variables, since get_collection flattens
  # out the list.
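
# A sketch of the grouping (illustrative only; names are hypothetical). A
# variable created with a partitioner appears in the collection once per
# slice, e.g. "w/part_0" and "w/part_1". Since `_infer_var_name` resolves each
# slice to the full variable name, both slices are grouped under one key and
# can be warm-started as a single logical tensor:
#
#   {"w": [<Variable "w/part_0">, <Variable "w/part_1">],
#    "b": [<Variable "b">]}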
  grouped_variables = {}
  for v in list_of_vars:
    if not isinstance(v, list):
      var_name = _infer_var_name([v])
    else:
      var_name = _infer_var_name(v)
    grouped_variables.setdefault(var_name, []).append(v)

  return grouped_variables


@tf_export(v1=["train.warm_start"])
def warm_start(ckpt_to_initialize_from,
               vars_to_warm_start=".*",
               var_name_to_vocab_info=None,
               var_name_to_prev_var_name=None):
  """Warm-starts a model using the given settings.

  If you are using a tf.estimator.Estimator, this will automatically be called
  during training.

  Args:
    ckpt_to_initialize_from: [Required] A string specifying the directory with
      checkpoint file(s) or path to checkpoint from which to warm-start the
      model parameters.
    vars_to_warm_start: [Optional] One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.get_collection). This expression will only consider
        variables in the TRAINABLE_VARIABLES collection -- if you need to
        warm-start non-TRAINABLE vars (such as optimizer accumulators or
        batch norm statistics), please use the below option.
      - A list of Variables to warm-start. If you do not have access to the
        `Variable` objects at the call site, please use the below option.
      - A list of strings, each a regex scope provided to tf.get_collection
        with GLOBAL_VARIABLES (please see tf.get_collection). For backwards
        compatibility reasons, this is separate from the single-string
        argument type.
      - `None`, in which case only variables specified in
        `var_name_to_vocab_info` will be warm-started.

      Defaults to `'.*'`, which warm-starts all variables in the
      TRAINABLE_VARIABLES collection. Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
      `tf.estimator.VocabInfo`. The variable names should be "full" variables,
      not the names of the partitions. If not explicitly provided, the
      variable is assumed to have no (changes to) vocabulary.
    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
      name of the previously-trained variable in `ckpt_to_initialize_from`. If
      not explicitly provided, the name of the variable is assumed to be same
      between previous checkpoint and current model. Note that this has no
      effect on the set of variables that is warm-started, and only controls
      name mapping (use `vars_to_warm_start` for controlling what variables to
      warm-start).

  Raises:
    ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
      configuration for variable names that are not used. This is to ensure
      a stronger check for variable configuration than relying on users to
      examine the logs.
  """
  if var_name_to_vocab_info is None:
    var_name_to_vocab_info = {}
  if var_name_to_prev_var_name is None:
    var_name_to_prev_var_name = {}
  logging.info("Warm-starting from: %s", (ckpt_to_initialize_from,))
  grouped_variables = _get_grouped_variables(vars_to_warm_start)

  # Keep track of which var_names in var_name_to_prev_var_name and
  # var_name_to_vocab_info have been used. Err on the safer side by throwing
  # an exception if any are unused by the end of the loop. It is easy to
  # misname a variable during this configuration, in which case without this
  # check, we would fail to warm-start silently.
  prev_var_name_used = set()
  vocab_info_used = set()

  # Group the vocabless vars into one call to init_from_checkpoint.
  vocabless_vars = {}
  for var_name, variable in six.iteritems(grouped_variables):
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      prev_var_name_used.add(var_name)
    vocab_info = var_name_to_vocab_info.get(var_name)
    if vocab_info:
      vocab_info_used.add(var_name)
      logging.info(
          "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
          " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
          " initializer: {}".format(
              var_name,
              vocab_info.new_vocab,
              vocab_info.new_vocab_size,
              vocab_info.old_vocab,
              (vocab_info.old_vocab_size if vocab_info.old_vocab_size > 0
               else "All"),
              vocab_info.num_oov_buckets,
              prev_var_name or "Unchanged",
              vocab_info.backup_initializer or "zero-initialized"))
      _warm_start_var_with_vocab(
          variable,
          current_vocab_path=vocab_info.new_vocab,
          current_vocab_size=vocab_info.new_vocab_size,
          prev_ckpt=ckpt_to_initialize_from,
          prev_vocab_path=vocab_info.old_vocab,
          previous_vocab_size=vocab_info.old_vocab_size,
          current_oov_buckets=vocab_info.num_oov_buckets,
          prev_tensor_name=prev_var_name,
          initializer=vocab_info.backup_initializer,
          axis=vocab_info.axis)
    else:
      # For the special value of vars_to_warm_start = None,
      # we only warm-start variables with explicitly specified vocabularies.
      if vars_to_warm_start:
        logging.info("Warm-starting variable: {}; prev_var_name: {}".format(
            var_name, prev_var_name or "Unchanged"))
        # Because we use a default empty list in grouped_variables, single
        # unpartitioned variables will be lists here, which we rectify in
        # order for init_from_checkpoint logic to work correctly.
        if len(variable) == 1:
          variable = variable[0]
        prev_tensor_name, var = _get_var_info(variable, prev_var_name)
        vocabless_vars[prev_tensor_name] = var

  checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from,
                                        vocabless_vars)
  prev_var_name_not_used = set(
      var_name_to_prev_var_name.keys()) - prev_var_name_used
  vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

  if prev_var_name_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_prev_var_name that were not used: "
        "{0}. Perhaps you misspelled them? Here is the list of viable "
        "variable names: {1}".format(prev_var_name_not_used,
                                     grouped_variables.keys()))
  if vocab_info_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_vocab_info that were not used: {0}. "
        "Perhaps you misspelled them? Here is the list of viable variable "
        "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))
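
# End-to-end usage sketches (illustrative only; checkpoint paths, vocab files,
# and variable names are hypothetical). `tf.compat.v1.train.warm_start` is the
# exported name of the function above:
#
#   # Warm-start all TRAINABLE_VARIABLES from a previous run:
#   tf.compat.v1.train.warm_start("/tmp/prev_model_dir")
#
#   # Warm-start only an embedding whose vocabulary changed, remapping rows
#   # via VocabInfo and leaving every other variable freshly initialized:
#   vocab_info = tf.compat.v1.train.VocabInfo(
#       new_vocab="new_vocab.txt",
#       new_vocab_size=100,
#       num_oov_buckets=1,
#       old_vocab="old_vocab.txt")
#   tf.compat.v1.train.warm_start(
#       "/tmp/prev_model_dir",
#       vars_to_warm_start=None,
#       var_name_to_vocab_info={
#           "input_layer/emb/embedding_weights": vocab_info})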