# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for warm_starting_util."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import six

from tensorflow.python.feature_column import feature_column_lib as fc
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import warm_starting_util as ws_util

# Short aliases for the initializers used throughout the tests below.
ones = init_ops.ones_initializer
norms = init_ops.truncated_normal_initializer
rand = init_ops.random_uniform_initializer
zeros = init_ops.zeros_initializer


class WarmStartingUtilTest(test.TestCase):
  """Tests for `warm_starting_util`.

  The general pattern of these tests is: build variables in one graph/session,
  write them to a checkpoint in the test's temp dir, then build a fresh graph
  and verify that warm-starting initializes the new variables from that
  checkpoint (optionally remapping rows through old/new vocabulary files).
  """

  def _write_vocab(self, string_values, file_name):
    """Writes `string_values` (one entry per line) to a temp-dir vocab file.

    Returns the full path of the written file.
    """
    vocab_file = os.path.join(self.get_temp_dir(), file_name)
    with open(vocab_file, "w") as f:
      f.write("\n".join(string_values))
    return vocab_file

  def _write_checkpoint(self, sess):
    """Initializes all global variables and saves a checkpoint.

    The checkpoint prefix is `<temp_dir>/model` with global_step=0, so the
    written file prefix is `<temp_dir>/model-0`.
    """
    self.evaluate(variables.global_variables_initializer())
    saver = saver_lib.Saver()
    ckpt_prefix = os.path.join(self.get_temp_dir(), "model")
    saver.save(sess, ckpt_prefix, global_step=0)

  def _create_prev_run_var(self,
                           var_name,
                           shape=None,
                           initializer=None,
                           partitioner=None):
    """Creates one variable in a throwaway graph and checkpoints it.

    Returns a tuple of (variable(s), evaluated value(s)).  If `partitioner`
    is given, the first element is the list of underlying partition
    variables rather than the `PartitionedVariable` wrapper.
    """
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        var = variable_scope.get_variable(
            var_name,
            shape=shape,
            initializer=initializer,
            partitioner=partitioner)
        self._write_checkpoint(sess)
        if partitioner:
          self.assertTrue(isinstance(var, variables.PartitionedVariable))
          var = var._get_variable_list()
        return var, self.evaluate(var)

  def _create_prev_run_vars(self,
                            var_names,
                            shapes,
                            initializers):
    """Creates several variables in a throwaway graph and checkpoints them.

    Returns the evaluated values, in the same order as `var_names`.
    """
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        all_vars = []
        for var_name, shape, initializer in zip(var_names, shapes,
                                                initializers):
          all_vars.append(variable_scope.get_variable(
              var_name,
              shape=shape,
              initializer=initializer))
        self._write_checkpoint(sess)
        return [self.evaluate(var) for var in all_vars]

  def _create_dummy_inputs(self):
    """Returns placeholder features matching the feature columns used below."""
    return {
        "sc_int": array_ops.sparse_placeholder(dtypes.int32),
        "sc_hash": array_ops.sparse_placeholder(dtypes.string),
        "sc_keys": array_ops.sparse_placeholder(dtypes.string),
        "sc_vocab": array_ops.sparse_placeholder(dtypes.string),
        "real": array_ops.placeholder(dtypes.float32)
    }

  def _create_linear_model(self, feature_cols, partitioner):
    """Builds a single-unit linear model over `feature_cols`.

    Returns the dict populated by `fc.linear_model` mapping each feature
    column to its list of variables.
    """
    cols_to_vars = {}
    with variable_scope.variable_scope("", partitioner=partitioner):
      # Create the variables.
      fc.linear_model(
          features=self._create_dummy_inputs(),
          feature_columns=feature_cols,
          units=1,
          cols_to_vars=cols_to_vars)
    # Return a dictionary mapping each column to its variable.
    return cols_to_vars

  def _assert_cols_to_vars(self, cols_to_vars, cols_to_expected_values, sess):
    """Asserts each column's variables evaluate to the expected values."""
    for col, expected_values in six.iteritems(cols_to_expected_values):
      for i, var in enumerate(cols_to_vars[col]):
        self.assertAllClose(expected_values[i], var.eval(sess))

  def testWarmStartVar(self):
    """Warm-starts an unpartitioned var from an unpartitioned checkpoint var."""
    _, prev_val = self._create_prev_run_var(
        "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
        prev_tensor_name, var = ws_util._get_var_info(fruit_weights)
        checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
                                              {prev_tensor_name: var})
        self.evaluate(variables.global_variables_initializer())
        self.assertAllClose(prev_val, fruit_weights.eval(sess))

  def testWarmStartVarPrevVarPartitioned(self):
    """Warm-starts an unpartitioned var from a partitioned checkpoint var."""
    _, weights = self._create_prev_run_var(
        "fruit_weights",
        shape=[4, 1],
        initializer=[[0.5], [1.], [1.5], [2.]],
        partitioner=lambda shape, dtype: [2, 1])
    prev_val = np.concatenate([weights[0], weights[1]], axis=0)

    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
        prev_tensor_name, var = ws_util._get_var_info(fruit_weights)
        checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
                                              {prev_tensor_name: var})
        self.evaluate(variables.global_variables_initializer())
        self.assertAllClose(prev_val, fruit_weights.eval(sess))

  def testWarmStartVarCurrentVarPartitioned(self):
    """Warm-starts a partitioned var from an unpartitioned checkpoint var."""
    _, prev_val = self._create_prev_run_var(
        "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights",
            shape=[4, 1],
            initializer=[[0.], [0.], [0.], [0.]],
            partitioner=lambda shape, dtype: [2, 1])
        self.assertTrue(
            isinstance(fruit_weights, variables.PartitionedVariable))
        prev_tensor_name, var = ws_util._get_var_info(fruit_weights)
        checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
                                              {prev_tensor_name: var})
        self.evaluate(variables.global_variables_initializer())
        fruit_weights = fruit_weights._get_variable_list()
        new_val = np.concatenate(
            [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
        self.assertAllClose(prev_val, new_val)

  def testWarmStartVarBothVarsPartitioned(self):
    """Warm-starts a partitioned var from a (renamed) partitioned ckpt var."""
    _, weights = self._create_prev_run_var(
        "old_scope/fruit_weights",
        shape=[4, 1],
        initializer=[[0.5], [1.], [1.5], [2.]],
        partitioner=lambda shape, dtype: [2, 1])
    prev_val = np.concatenate([weights[0], weights[1]], axis=0)
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "new_scope/fruit_weights",
            shape=[4, 1],
            initializer=[[0.], [0.], [0.], [0.]],
            partitioner=lambda shape, dtype: [2, 1])
        self.assertTrue(
            isinstance(fruit_weights, variables.PartitionedVariable))
        # Current var name differs from the checkpoint name, so pass the
        # previous tensor name explicitly.
        prev_tensor_name, var = ws_util._get_var_info(
            fruit_weights, prev_tensor_name="old_scope/fruit_weights")
        checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
                                              {prev_tensor_name: var})
        self.evaluate(variables.global_variables_initializer())
        fruit_weights = fruit_weights._get_variable_list()
        new_val = np.concatenate(
            [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
        self.assertAllClose(prev_val, new_val)

  def testWarmStartVarWithVocab(self):
    """Warm-starts with a row remap from an old vocab to a new vocab."""
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    self._create_prev_run_var(
        "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights", initializer=[[0.], [0.], [0.], [0.], [0.]])
        ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
                                           self.get_temp_dir(), prev_vocab_path)
        self.evaluate(variables.global_variables_initializer())
        # Rows follow the new vocab order; the unseen entry stays at 0.
        self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
                            fruit_weights.eval(sess))

  def testWarmStartVarWithColumnVocab(self):
    """Warm-starts with a vocab remap along axis=1 (output columns)."""
    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
    self._create_prev_run_var(
        "fruit_output_layer",
        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
                                       "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_output_layer = variable_scope.get_variable(
            "fruit_output_layer",
            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
                         [0., 0., 0.]])
        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
                                           current_vocab_size=3,
                                           prev_ckpt=self.get_temp_dir(),
                                           prev_vocab_path=prev_vocab_path,
                                           axis=1)
        self.evaluate(variables.global_variables_initializer())
        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
                             [2.3, 2., 0.]], fruit_output_layer.eval(sess))

  def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
    """Warm-starts with the old vocab truncated via `previous_vocab_size`."""
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    self._create_prev_run_var(
        "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights", initializer=[[0.], [0.], [0.], [0.], [0.]])
        ws_util._warm_start_var_with_vocab(
            fruit_weights,
            new_vocab_path,
            5,
            self.get_temp_dir(),
            prev_vocab_path,
            previous_vocab_size=2)
        self.evaluate(variables.global_variables_initializer())
        # Old vocabulary limited to ['apple', 'banana'].
        self.assertAllClose([[0.], [0.], [1.], [0.5], [0.]],
                            fruit_weights.eval(sess))

  def testWarmStartVarWithVocabPrevVarPartitioned(self):
    """Vocab remap where the checkpoint variable was partitioned."""
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    self._create_prev_run_var(
        "fruit_weights",
        shape=[4, 1],
        initializer=[[0.5], [1.], [1.5], [2.]],
        partitioner=lambda shape, dtype: [2, 1])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights", initializer=[[0.], [0.], [0.], [0.], [0.]])
        ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
                                           self.get_temp_dir(), prev_vocab_path)
        self.evaluate(variables.global_variables_initializer())
        self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
                            fruit_weights.eval(sess))

  def testWarmStartVarWithColumnVocabPrevVarPartitioned(self):
    """Axis=1 vocab remap where the checkpoint variable was partitioned."""
    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
    self._create_prev_run_var(
        "fruit_output_layer",
        shape=[4, 2],
        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
        partitioner=lambda shape, dtype: [2, 1])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
                                       "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_output_layer = variable_scope.get_variable(
            "fruit_output_layer",
            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
                         [0., 0., 0.]])
        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
                                           current_vocab_size=3,
                                           prev_ckpt=self.get_temp_dir(),
                                           prev_vocab_path=prev_vocab_path,
                                           axis=1)
        self.evaluate(variables.global_variables_initializer())
        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
                             [2.3, 2., 0.]], fruit_output_layer.eval(sess))

  def testWarmStartVarWithVocabCurrentVarPartitioned(self):
    """Vocab remap where the current variable is partitioned (with OOV row)."""
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    self._create_prev_run_var(
        "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights",
            shape=[6, 1],
            initializer=[[0.], [0.], [0.], [0.], [0.], [0.]],
            partitioner=lambda shape, dtype: [2, 1])
        ws_util._warm_start_var_with_vocab(
            fruit_weights,
            new_vocab_path,
            5,
            self.get_temp_dir(),
            prev_vocab_path,
            current_oov_buckets=1)
        self.evaluate(variables.global_variables_initializer())
        self.assertTrue(
            isinstance(fruit_weights, variables.PartitionedVariable))
        fruit_weights_vars = fruit_weights._get_variable_list()
        # First partition holds rows 0-2, second holds rows 3-5 (incl. OOV).
        self.assertAllClose([[2.], [1.5], [1.]],
                            fruit_weights_vars[0].eval(sess))
        self.assertAllClose([[0.5], [0.], [0.]],
                            fruit_weights_vars[1].eval(sess))

  def testWarmStartVarWithColumnVocabCurrentVarPartitioned(self):
    """Axis=1 vocab remap where the current variable is partitioned."""
    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
    self._create_prev_run_var(
        "fruit_output_layer",
        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
                                       "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_output_layer = variable_scope.get_variable(
            "fruit_output_layer",
            shape=[4, 3],
            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
                         [0., 0., 0.]],
            partitioner=lambda shape, dtype: [2, 1])
        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
                                           current_vocab_size=3,
                                           prev_ckpt=self.get_temp_dir(),
                                           prev_vocab_path=prev_vocab_path,
                                           axis=1)
        self.evaluate(variables.global_variables_initializer())
        self.assertTrue(
            isinstance(fruit_output_layer, variables.PartitionedVariable))
        fruit_output_layer_vars = fruit_output_layer._get_variable_list()
        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
                            fruit_output_layer_vars[0].eval(sess))
        self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
                            fruit_output_layer_vars[1].eval(sess))

  def testWarmStartVarWithVocabBothVarsPartitioned(self):
    """Vocab remap where both checkpoint and current vars are partitioned."""
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    self._create_prev_run_var(
        "fruit_weights",
        shape=[4, 1],
        initializer=[[0.5], [1.], [1.5], [2.]],
        partitioner=lambda shape, dtype: [2, 1])

    # New vocab with elements in reverse order and two new elements.
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry",
         "blueberry"], "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights",
            shape=[6, 1],
            initializer=[[0.], [0.], [0.], [0.], [0.], [0.]],
            partitioner=lambda shape, dtype: [2, 1])
        ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 6,
                                           self.get_temp_dir(), prev_vocab_path)
        self.evaluate(variables.global_variables_initializer())
        self.assertTrue(
            isinstance(fruit_weights, variables.PartitionedVariable))
        fruit_weights_vars = fruit_weights._get_variable_list()
        self.assertAllClose([[2.], [1.5], [1.]],
                            fruit_weights_vars[0].eval(sess))
        self.assertAllClose([[0.5], [0.], [0.]],
                            fruit_weights_vars[1].eval(sess))

  def testWarmStartVarWithColumnVocabBothVarsPartitioned(self):
    """Axis=1 vocab remap where both old and new vars are partitioned."""
    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
    self._create_prev_run_var(
        "fruit_output_layer",
        shape=[4, 2],
        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
        partitioner=lambda shape, dtype: [2, 1])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
                                       "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_output_layer = variable_scope.get_variable(
            "fruit_output_layer",
            shape=[4, 3],
            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
                         [0., 0., 0.]],
            partitioner=lambda shape, dtype: [2, 1])
        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
                                           current_vocab_size=3,
                                           prev_ckpt=self.get_temp_dir(),
                                           prev_vocab_path=prev_vocab_path,
                                           axis=1)
        self.evaluate(variables.global_variables_initializer())
        self.assertTrue(
            isinstance(fruit_output_layer, variables.PartitionedVariable))
        fruit_output_layer_vars = fruit_output_layer._get_variable_list()
        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
                            fruit_output_layer_vars[0].eval(sess))
        self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
                            fruit_output_layer_vars[1].eval(sess))

  def testWarmStart_ListOfVariables(self):
    """`warm_start` accepts a list of Variable objects."""
    # Save checkpoint from which to warm-start.
    _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
                                                initializer=ones())
    # Verify we initialized the values correctly.
    self.assertAllEqual(np.ones([10, 1]), prev_int_val)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        # Initialize with zeros.
        var = variable_scope.get_variable(
            "v1",
            shape=[10, 1],
            initializer=zeros())
        ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=[var])
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started (init overridden to ones).
        self.assertAllEqual(var.eval(), prev_int_val)

  def testWarmStart_ListOfStrings(self):
    """`warm_start` accepts a list of variable-name strings."""
    # Save checkpoint from which to warm-start.
    _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
                                                initializer=ones())
    # Verify we initialized the values correctly.
    self.assertAllEqual(np.ones([10, 1]), prev_int_val)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        # Initialize with zeros.
        var = variable_scope.get_variable(
            "v1",
            shape=[10, 1],
            initializer=zeros())
        ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=["v1"])
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started (init overridden to ones).
        self.assertAllEqual(var.eval(), prev_int_val)

  def testWarmStart_ListOfRegexes(self):
    """`warm_start` treats list entries as regexes selecting variables."""
    # Save checkpoint from which to warm-start.
    [prev_v1_val, prev_v1_momentum_val,
     prev_v2_val, _] = self._create_prev_run_vars(
         var_names=["v1", "v1/Momentum", "v2", "v2/Momentum"],
         shapes=[[10, 1]] * 4,
         initializers=[ones()] * 4)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        # Initialize with zeros.
        v1 = variable_scope.get_variable(
            "v1",
            shape=[10, 1],
            initializer=zeros())
        v1_momentum = variable_scope.get_variable(
            "v1/Momentum",
            shape=[10, 1],
            initializer=zeros())
        v2 = variable_scope.get_variable(
            "v2",
            shape=[10, 1],
            initializer=zeros())
        v2_momentum = variable_scope.get_variable(
            "v2/Momentum",
            shape=[10, 1],
            initializer=zeros())
        ws_util.warm_start(self.get_temp_dir(),
                           # This warm-starts both v1 and v1/Momentum, but only
                           # v2 (and not v2/Momentum).
                           vars_to_warm_start=["v1", "v2[^/]"])
        self.evaluate(variables.global_variables_initializer())
        # Verify the selection of weights were correctly warm-started (init
        # overridden to ones).
        self.assertAllEqual(v1.eval(), prev_v1_val)
        self.assertAllEqual(v1_momentum.eval(), prev_v1_momentum_val)
        self.assertAllEqual(v2.eval(), prev_v2_val)
        self.assertAllEqual(v2_momentum.eval(), np.zeros([10, 1]))

  def testWarmStart_SparseColumnIntegerized(self):
    """Warm-starts the weights of an identity categorical column."""
    # Create feature column.
    sc_int = fc.categorical_column_with_identity("sc_int", num_buckets=10)

    # Save checkpoint from which to warm-start.
    _, prev_int_val = self._create_prev_run_var(
        "linear_model/sc_int/weights", shape=[10, 1], initializer=ones())
    # Verify we initialized the values correctly.
    self.assertAllEqual(np.ones([10, 1]), prev_int_val)

    partitioner = lambda shape, dtype: [1] * len(shape)
    # New graph, new session WITHOUT warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_int], partitioner)
        self.evaluate(variables.global_variables_initializer())
        # Without warm-starting, the weights should be initialized using default
        # initializer (which is init_ops.zeros_initializer).
        self._assert_cols_to_vars(cols_to_vars, {sc_int: [np.zeros([10, 1])]},
                                  sess)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_int], partitioner)
        ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=".*sc_int.*")
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started.
        self._assert_cols_to_vars(cols_to_vars, {sc_int: [prev_int_val]}, sess)

  def testWarmStart_SparseColumnHashed(self):
    """Warm-starts the weights of a hash-bucket categorical column."""
    # Create feature column.
    sc_hash = fc.categorical_column_with_hash_bucket(
        "sc_hash", hash_bucket_size=15)

    # Save checkpoint from which to warm-start.
    _, prev_hash_val = self._create_prev_run_var(
        "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())

    partitioner = lambda shape, dtype: [1] * len(shape)
    # New graph, new session WITHOUT warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_hash], partitioner)
        self.evaluate(variables.global_variables_initializer())
        # Without warm-starting, the weights should be initialized using default
        # initializer (which is init_ops.zeros_initializer).
        self._assert_cols_to_vars(cols_to_vars, {sc_hash: [np.zeros([15, 1])]},
                                  sess)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_hash], partitioner)
        ws_util.warm_start(
            self.get_temp_dir(), vars_to_warm_start=".*sc_hash.*")
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started.
        self._assert_cols_to_vars(cols_to_vars, {sc_hash: [prev_hash_val]},
                                  sess)

  def testWarmStart_SparseColumnVocabulary(self):
    """Warm-starts a vocabulary column when old and new vocab are the same."""
    # Create vocab for sparse column "sc_vocab".
    vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                   "vocab")
    # Create feature column.
    sc_vocab = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)

    # Save checkpoint from which to warm-start.
    _, prev_vocab_val = self._create_prev_run_var(
        "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())

    partitioner = lambda shape, dtype: [1] * len(shape)
    # New graph, new session WITHOUT warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
        self.evaluate(variables.global_variables_initializer())
        # Without warm-starting, the weights should be initialized using default
        # initializer (which is init_ops.zeros_initializer).
        self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([4, 1])]},
                                  sess)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
        # Since old vocab is not explicitly set in WarmStartSettings, the old
        # vocab is assumed to be same as new vocab.
        ws_util.warm_start(
            self.get_temp_dir(), vars_to_warm_start=".*sc_vocab.*")
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started.
        self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
                                  sess)

  def testWarmStart_ExplicitCheckpointFile(self):
    """Warm-starts from an explicit checkpoint file prefix, not a directory."""
    # Create vocab for sparse column "sc_vocab".
    vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                   "vocab")
    # Create feature column.
    sc_vocab = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)

    # Save checkpoint from which to warm-start.
    _, prev_vocab_val = self._create_prev_run_var(
        "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())

    partitioner = lambda shape, dtype: [1] * len(shape)
    # New graph, new session WITHOUT warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
        self.evaluate(variables.global_variables_initializer())
        # Without warm-starting, the weights should be initialized using default
        # initializer (which is init_ops.zeros_initializer).
        self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([4, 1])]},
                                  sess)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
        # Since old vocab is not explicitly set in WarmStartSettings, the old
        # vocab is assumed to be same as new vocab.
        ws_util.warm_start(
            # Explicitly provide the file prefix instead of just the dir.
            os.path.join(self.get_temp_dir(), "model-0"),
            vars_to_warm_start=".*sc_vocab.*")
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started.
        self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
                                  sess)

  def testWarmStart_SparseColumnVocabularyConstrainedVocabSizes(self):
    """Warm-starts with both old and new vocab sizes constrained."""
    # Create old vocabulary, and use a size smaller than the total number of
    # entries.
    old_vocab_path = self._write_vocab(["apple", "guava", "banana"],
                                       "old_vocab")
    old_vocab_size = 2  # ['apple', 'guava']

    # Create new vocab for sparse column "sc_vocab".
    current_vocab_path = self._write_vocab(
        ["apple", "banana", "guava", "orange"], "current_vocab")
    # Create feature column. Only use 2 of the actual entries, resulting in
    # ['apple', 'banana'] for the new vocabulary.
    sc_vocab = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=current_vocab_path, vocabulary_size=2)

    # Save checkpoint from which to warm-start.
    self._create_prev_run_var(
        "linear_model/sc_vocab/weights", shape=[2, 1], initializer=ones())

    partitioner = lambda shape, dtype: [1] * len(shape)
    # New graph, new session WITHOUT warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
        self.evaluate(variables.global_variables_initializer())
        # Without warm-starting, the weights should be initialized using default
        # initializer (which is init_ops.zeros_initializer).
        self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([2, 1])]},
                                  sess)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
        vocab_info = ws_util.VocabInfo(
            new_vocab=sc_vocab.vocabulary_file,
            new_vocab_size=sc_vocab.vocabulary_size,
            num_oov_buckets=sc_vocab.num_oov_buckets,
            old_vocab=old_vocab_path,
            old_vocab_size=old_vocab_size)
        ws_util.warm_start(
            ckpt_to_initialize_from=self.get_temp_dir(),
            vars_to_warm_start=".*sc_vocab.*",
            var_name_to_vocab_info={
                "linear_model/sc_vocab/weights": vocab_info
            })
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started. 'banana' isn't in the
        # first two entries of the old vocabulary, so it's newly initialized.
        self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [[[1], [0]]]}, sess)

  def testWarmStart_BucketizedColumn(self):
    """Warm-starts the weights of a bucketized numeric column."""
    # Create feature column.
    real = fc.numeric_column("real")
    real_bucket = fc.bucketized_column(real, boundaries=[0., 1., 2., 3.])

    # Save checkpoint from which to warm-start.
    _, prev_bucket_val = self._create_prev_run_var(
        "linear_model/real_bucketized/weights",
        shape=[5, 1],
        initializer=norms())

    partitioner = lambda shape, dtype: [1] * len(shape)
    # New graph, new session WITHOUT warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([real_bucket], partitioner)
        self.evaluate(variables.global_variables_initializer())
        # Without warm-starting, the weights should be initialized using default
        # initializer (which is init_ops.zeros_initializer).
        self._assert_cols_to_vars(cols_to_vars,
                                  {real_bucket: [np.zeros([5, 1])]}, sess)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([real_bucket], partitioner)
        ws_util.warm_start(
            self.get_temp_dir(), vars_to_warm_start=".*real_bucketized.*")
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started.
        self._assert_cols_to_vars(cols_to_vars,
                                  {real_bucket: [prev_bucket_val]}, sess)

  def testWarmStart_MultipleCols(self):
    """Warm-starts a linear model with every supported column type at once."""
    # Create vocab for sparse column "sc_vocab".
    vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                   "vocab")

    # Create feature columns.
    sc_int = fc.categorical_column_with_identity("sc_int", num_buckets=10)
    sc_hash = fc.categorical_column_with_hash_bucket(
        "sc_hash", hash_bucket_size=15)
    sc_keys = fc.categorical_column_with_vocabulary_list(
        "sc_keys", vocabulary_list=["a", "b", "c", "e"])
    sc_vocab = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)
    real = fc.numeric_column("real")
    real_bucket = fc.bucketized_column(real, boundaries=[0., 1., 2., 3.])
    cross = fc.crossed_column([sc_keys, sc_vocab], hash_bucket_size=20)
    all_linear_cols = [sc_int, sc_hash, sc_keys, sc_vocab, real_bucket, cross]

    # Save checkpoint from which to warm-start. Also create a bias variable,
    # so we can check that it's also warm-started.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        sc_int_weights = variable_scope.get_variable(
            "linear_model/sc_int/weights", shape=[10, 1], initializer=ones())
        sc_hash_weights = variable_scope.get_variable(
            "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
        sc_keys_weights = variable_scope.get_variable(
            "linear_model/sc_keys/weights", shape=[4, 1], initializer=rand())
        sc_vocab_weights = variable_scope.get_variable(
            "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())
        real_bucket_weights = variable_scope.get_variable(
            "linear_model/real_bucketized/weights",
            shape=[5, 1],
            initializer=norms())
        cross_weights = variable_scope.get_variable(
            "linear_model/sc_keys_X_sc_vocab/weights",
            shape=[20, 1],
            initializer=rand())
        bias = variable_scope.get_variable(
            "linear_model/bias_weights",
            shape=[1],
            initializer=rand())
        self._write_checkpoint(sess)
        (prev_int_val, prev_hash_val, prev_keys_val, prev_vocab_val,
         prev_bucket_val, prev_cross_val, prev_bias_val) = sess.run([
             sc_int_weights, sc_hash_weights, sc_keys_weights, sc_vocab_weights,
             real_bucket_weights, cross_weights, bias
         ])

    partitioner = lambda shape, dtype: [1] * len(shape)
    # New graph, new session WITHOUT warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model(all_linear_cols, partitioner)
        self.evaluate(variables.global_variables_initializer())
        # Without warm-starting, all weights should be initialized using default
        # initializer (which is init_ops.zeros_initializer).
        self._assert_cols_to_vars(cols_to_vars, {
            sc_int: [np.zeros([10, 1])],
            sc_hash: [np.zeros([15, 1])],
            sc_keys: [np.zeros([4, 1])],
            sc_vocab: [np.zeros([4, 1])],
            real_bucket: [np.zeros([5, 1])],
            cross: [np.zeros([20, 1])],
        }, sess)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model(all_linear_cols, partitioner)
        vocab_info = ws_util.VocabInfo(
            new_vocab=sc_vocab.vocabulary_file,
            new_vocab_size=sc_vocab.vocabulary_size,
            num_oov_buckets=sc_vocab.num_oov_buckets,
            old_vocab=vocab_path)
        ws_util.warm_start(
            self.get_temp_dir(),
            var_name_to_vocab_info={
                "linear_model/sc_vocab/weights": vocab_info
            })
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started.
        self._assert_cols_to_vars(cols_to_vars, {
            sc_int: [prev_int_val],
            sc_hash: [prev_hash_val],
            sc_keys: [prev_keys_val],
            sc_vocab: [prev_vocab_val],
            real_bucket: [prev_bucket_val],
            cross: [prev_cross_val],
            "bias": [prev_bias_val],
        }, sess)

  def testWarmStartMoreSettings(self):
    # NOTE(review): this method continues beyond the visible portion of the
    # file; only the visible lines are documented here.
    # Create old and new vocabs for sparse column "sc_vocab".
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry",
         "blueberry"], "new_vocab")
    # Create feature columns.
    sc_hash = fc.categorical_column_with_hash_bucket(
        "sc_hash", hash_bucket_size=15)
    sc_keys = fc.categorical_column_with_vocabulary_list(
        "sc_keys", vocabulary_list=["a", "b", "c", "e"])
    sc_vocab = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
    all_linear_cols = [sc_hash, sc_keys, sc_vocab]

    # Save checkpoint from which to warm-start.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        variable_scope.get_variable(
            "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
        sc_keys_weights = variable_scope.get_variable(
            "some_other_name", shape=[4, 1], initializer=rand())
        variable_scope.get_variable(
            "linear_model/sc_vocab/weights",
            initializer=[[0.5], [1.], [2.], [3.]])
        self._write_checkpoint(sess)
        prev_keys_val = self.evaluate(sc_keys_weights)

    def _partitioner(shape, dtype):  # pylint:disable=unused-argument
      # Partition each var into 2 equal slices.
      partitions = [1] * len(shape)
      partitions[0] = min(2, shape.dims[0].value)
      return partitions

    # New graph, new session with warm-starting.
877 with ops.Graph().as_default() as g: 878 with self.session(graph=g) as sess: 879 cols_to_vars = self._create_linear_model(all_linear_cols, _partitioner) 880 vocab_info = ws_util.VocabInfo( 881 new_vocab=sc_vocab.vocabulary_file, 882 new_vocab_size=sc_vocab.vocabulary_size, 883 num_oov_buckets=sc_vocab.num_oov_buckets, 884 old_vocab=prev_vocab_path) 885 ws_util.warm_start( 886 self.get_temp_dir(), 887 vars_to_warm_start=".*(sc_keys|sc_vocab).*", 888 var_name_to_vocab_info={ 889 ws_util._infer_var_name(cols_to_vars[sc_vocab]): vocab_info 890 }, 891 var_name_to_prev_var_name={ 892 ws_util._infer_var_name(cols_to_vars[sc_keys]): 893 "some_other_name" 894 }) 895 self.evaluate(variables.global_variables_initializer()) 896 # Verify weights were correctly warm-started. Var corresponding to 897 # sc_hash should not be warm-started. Var corresponding to sc_vocab 898 # should be correctly warm-started after vocab remapping. 899 self._assert_cols_to_vars(cols_to_vars, { 900 sc_keys: 901 np.split(prev_keys_val, 2), 902 sc_hash: [np.zeros([8, 1]), np.zeros([7, 1])], 903 sc_vocab: [ 904 np.array([[3.], [2.], [1.]]), 905 np.array([[0.5], [0.], [0.]]) 906 ] 907 }, sess) 908 909 def testWarmStartMoreSettingsNoPartitioning(self): 910 # Create old and new vocabs for sparse column "sc_vocab". 911 prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], 912 "old_vocab") 913 new_vocab_path = self._write_vocab( 914 ["orange", "guava", "banana", "apple", "raspberry", 915 "blueberry"], "new_vocab") 916 # Create feature columns. 917 sc_hash = fc.categorical_column_with_hash_bucket( 918 "sc_hash", hash_bucket_size=15) 919 sc_keys = fc.categorical_column_with_vocabulary_list( 920 "sc_keys", vocabulary_list=["a", "b", "c", "e"]) 921 sc_vocab = fc.categorical_column_with_vocabulary_file( 922 "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6) 923 all_linear_cols = [sc_hash, sc_keys, sc_vocab] 924 925 # Save checkpoint from which to warm-start. 
926 with ops.Graph().as_default() as g: 927 with self.session(graph=g) as sess: 928 variable_scope.get_variable( 929 "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms()) 930 sc_keys_weights = variable_scope.get_variable( 931 "some_other_name", shape=[4, 1], initializer=rand()) 932 variable_scope.get_variable( 933 "linear_model/sc_vocab/weights", 934 initializer=[[0.5], [1.], [2.], [3.]]) 935 self._write_checkpoint(sess) 936 prev_keys_val = self.evaluate(sc_keys_weights) 937 938 # New graph, new session with warm-starting. 939 with ops.Graph().as_default() as g: 940 with self.session(graph=g) as sess: 941 cols_to_vars = self._create_linear_model(all_linear_cols, 942 partitioner=None) 943 vocab_info = ws_util.VocabInfo( 944 new_vocab=sc_vocab.vocabulary_file, 945 new_vocab_size=sc_vocab.vocabulary_size, 946 num_oov_buckets=sc_vocab.num_oov_buckets, 947 old_vocab=prev_vocab_path) 948 ws_util.warm_start( 949 self.get_temp_dir(), 950 vars_to_warm_start=".*(sc_keys|sc_vocab).*", 951 var_name_to_vocab_info={ 952 ws_util._infer_var_name(cols_to_vars[sc_vocab]): vocab_info 953 }, 954 var_name_to_prev_var_name={ 955 ws_util._infer_var_name(cols_to_vars[sc_keys]): 956 "some_other_name" 957 }) 958 self.evaluate(variables.global_variables_initializer()) 959 # Verify weights were correctly warm-started. Var corresponding to 960 # sc_hash should not be warm-started. Var corresponding to sc_vocab 961 # should be correctly warm-started after vocab remapping. 962 self._assert_cols_to_vars(cols_to_vars, { 963 sc_keys: [prev_keys_val], 964 sc_hash: [np.zeros([15, 1])], 965 sc_vocab: [np.array([[3.], [2.], [1.], [0.5], [0.], [0.]])] 966 }, sess) 967 968 def testWarmStartVarsToWarmstartIsNone(self): 969 # Create old and new vocabs for sparse column "sc_vocab". 
970 prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], 971 "old_vocab") 972 new_vocab_path = self._write_vocab( 973 ["orange", "guava", "banana", "apple", "raspberry", 974 "blueberry"], "new_vocab") 975 # Create feature columns. 976 sc_hash = fc.categorical_column_with_hash_bucket( 977 "sc_hash", hash_bucket_size=15) 978 sc_keys = fc.categorical_column_with_vocabulary_list( 979 "sc_keys", vocabulary_list=["a", "b", "c", "e"]) 980 sc_vocab = fc.categorical_column_with_vocabulary_file( 981 "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6) 982 all_linear_cols = [sc_hash, sc_keys, sc_vocab] 983 984 # Save checkpoint from which to warm-start. 985 with ops.Graph().as_default() as g: 986 with self.session(graph=g) as sess: 987 variable_scope.get_variable( 988 "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms()) 989 variable_scope.get_variable( 990 "some_other_name", shape=[4, 1], initializer=rand()) 991 variable_scope.get_variable( 992 "linear_model/sc_vocab/weights", 993 initializer=[[0.5], [1.], [2.], [3.]]) 994 self._write_checkpoint(sess) 995 996 def _partitioner(shape, dtype): # pylint:disable=unused-argument 997 # Partition each var into 2 equal slices. 998 partitions = [1] * len(shape) 999 partitions[0] = min(2, shape.dims[0].value) 1000 return partitions 1001 1002 # New graph, new session with warm-starting. 1003 with ops.Graph().as_default() as g: 1004 with self.session(graph=g) as sess: 1005 cols_to_vars = self._create_linear_model(all_linear_cols, _partitioner) 1006 vocab_info = ws_util.VocabInfo( 1007 new_vocab=sc_vocab.vocabulary_file, 1008 new_vocab_size=sc_vocab.vocabulary_size, 1009 num_oov_buckets=sc_vocab.num_oov_buckets, 1010 old_vocab=prev_vocab_path) 1011 ws_util.warm_start( 1012 self.get_temp_dir(), 1013 # The special value of None here will ensure that only the variable 1014 # specified in var_name_to_vocab_info (sc_vocab embedding) is 1015 # warm-started. 
1016 vars_to_warm_start=None, 1017 var_name_to_vocab_info={ 1018 ws_util._infer_var_name(cols_to_vars[sc_vocab]): vocab_info 1019 }, 1020 # Even though this is provided, the None value for 1021 # vars_to_warm_start overrides the logic, and this will not be 1022 # warm-started. 1023 var_name_to_prev_var_name={ 1024 ws_util._infer_var_name(cols_to_vars[sc_keys]): 1025 "some_other_name" 1026 }) 1027 self.evaluate(variables.global_variables_initializer()) 1028 # Verify weights were correctly warm-started. Var corresponding to 1029 # sc_vocab should be correctly warm-started after vocab remapping, 1030 # and neither of the other two should be warm-started.. 1031 self._assert_cols_to_vars(cols_to_vars, { 1032 sc_keys: [np.zeros([2, 1]), np.zeros([2, 1])], 1033 sc_hash: [np.zeros([8, 1]), np.zeros([7, 1])], 1034 sc_vocab: [ 1035 np.array([[3.], [2.], [1.]]), 1036 np.array([[0.5], [0.], [0.]]) 1037 ] 1038 }, sess) 1039 1040 def testWarmStartEmbeddingColumn(self): 1041 # Create old and new vocabs for embedding column "sc_vocab". 1042 prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], 1043 "old_vocab") 1044 new_vocab_path = self._write_vocab( 1045 ["orange", "guava", "banana", "apple", "raspberry", "blueberry"], 1046 "new_vocab") 1047 1048 # Save checkpoint from which to warm-start. 1049 with ops.Graph().as_default() as g: 1050 with self.session(graph=g) as sess: 1051 variable_scope.get_variable( 1052 "input_layer/sc_vocab_embedding/embedding_weights", 1053 initializer=[[0.5, 0.4], [1., 1.1], [2., 2.2], [3., 3.3]]) 1054 self._write_checkpoint(sess) 1055 1056 def _partitioner(shape, dtype): # pylint:disable=unused-argument 1057 # Partition each var into 2 equal slices. 1058 partitions = [1] * len(shape) 1059 partitions[0] = min(2, shape.dims[0].value) 1060 return partitions 1061 1062 # Create feature columns. 
1063 sc_vocab = fc.categorical_column_with_vocabulary_file( 1064 "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6) 1065 emb_vocab_column = fc.embedding_column( 1066 categorical_column=sc_vocab, 1067 dimension=2) 1068 all_deep_cols = [emb_vocab_column] 1069 # New graph, new session with warm-starting. 1070 with ops.Graph().as_default() as g: 1071 with self.session(graph=g) as sess: 1072 cols_to_vars = {} 1073 with variable_scope.variable_scope("", partitioner=_partitioner): 1074 # Create the variables. 1075 fc.input_layer( 1076 features=self._create_dummy_inputs(), 1077 feature_columns=all_deep_cols, 1078 cols_to_vars=cols_to_vars) 1079 vocab_info = ws_util.VocabInfo( 1080 new_vocab=sc_vocab.vocabulary_file, 1081 new_vocab_size=sc_vocab.vocabulary_size, 1082 num_oov_buckets=sc_vocab.num_oov_buckets, 1083 old_vocab=prev_vocab_path, 1084 # Can't use constant_initializer with load_and_remap. In practice, 1085 # use a truncated normal initializer. 1086 backup_initializer=init_ops.random_uniform_initializer( 1087 minval=0.42, maxval=0.42)) 1088 ws_util.warm_start( 1089 self.get_temp_dir(), 1090 var_name_to_vocab_info={ 1091 ws_util._infer_var_name(cols_to_vars[emb_vocab_column]): 1092 vocab_info 1093 }) 1094 self.evaluate(variables.global_variables_initializer()) 1095 # Verify weights were correctly warm-started. Var corresponding to 1096 # emb_vocab_column should be correctly warm-started after vocab 1097 # remapping. Missing values are filled in with the EmbeddingColumn's 1098 # initializer. 1099 self._assert_cols_to_vars( 1100 cols_to_vars, { 1101 emb_vocab_column: [ 1102 np.array([[3., 3.3], [2., 2.2], [1., 1.1]]), 1103 np.array([[0.5, 0.4], [0.42, 0.42], [0.42, 0.42]]) 1104 ] 1105 }, sess) 1106 1107 def testWarmStartEmbeddingColumnLinearModel(self): 1108 # Create old and new vocabs for embedding column "sc_vocab". 
1109 prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], 1110 "old_vocab") 1111 new_vocab_path = self._write_vocab( 1112 ["orange", "guava", "banana", "apple", "raspberry", "blueberry"], 1113 "new_vocab") 1114 1115 # Save checkpoint from which to warm-start. 1116 with ops.Graph().as_default() as g: 1117 with self.session(graph=g) as sess: 1118 variable_scope.get_variable( 1119 "linear_model/sc_vocab_embedding/embedding_weights", 1120 initializer=[[0.5, 0.4], [1., 1.1], [2., 2.2], [3., 3.3]]) 1121 variable_scope.get_variable( 1122 "linear_model/sc_vocab_embedding/weights", 1123 initializer=[[0.69], [0.71]]) 1124 self._write_checkpoint(sess) 1125 1126 def _partitioner(shape, dtype): # pylint:disable=unused-argument 1127 # Partition each var into 2 equal slices. 1128 partitions = [1] * len(shape) 1129 partitions[0] = min(2, shape.dims[0].value) 1130 return partitions 1131 1132 # Create feature columns. 1133 sc_vocab = fc.categorical_column_with_vocabulary_file( 1134 "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6) 1135 emb_vocab = fc.embedding_column( 1136 categorical_column=sc_vocab, 1137 dimension=2) 1138 all_deep_cols = [emb_vocab] 1139 # New graph, new session with warm-starting. 1140 with ops.Graph().as_default() as g: 1141 with self.session(graph=g) as sess: 1142 cols_to_vars = {} 1143 with variable_scope.variable_scope("", partitioner=_partitioner): 1144 # Create the variables. 1145 fc.linear_model( 1146 features=self._create_dummy_inputs(), 1147 feature_columns=all_deep_cols, 1148 cols_to_vars=cols_to_vars) 1149 1150 # Construct the vocab_info for the embedding weight. 1151 vocab_info = ws_util.VocabInfo( 1152 new_vocab=sc_vocab.vocabulary_file, 1153 new_vocab_size=sc_vocab.vocabulary_size, 1154 num_oov_buckets=sc_vocab.num_oov_buckets, 1155 old_vocab=prev_vocab_path, 1156 # Can't use constant_initializer with load_and_remap. In practice, 1157 # use a truncated normal initializer. 
1158 backup_initializer=init_ops.random_uniform_initializer( 1159 minval=0.42, maxval=0.42)) 1160 ws_util.warm_start( 1161 self.get_temp_dir(), 1162 vars_to_warm_start=".*sc_vocab.*", 1163 var_name_to_vocab_info={ 1164 "linear_model/sc_vocab_embedding/embedding_weights": vocab_info 1165 }) 1166 self.evaluate(variables.global_variables_initializer()) 1167 # Verify weights were correctly warm-started. Var corresponding to 1168 # emb_vocab should be correctly warm-started after vocab remapping. 1169 # Missing values are filled in with the EmbeddingColumn's initializer. 1170 self._assert_cols_to_vars( 1171 cols_to_vars, 1172 { 1173 emb_vocab: [ 1174 # linear weights part 0. 1175 np.array([[0.69]]), 1176 # linear weights part 1. 1177 np.array([[0.71]]), 1178 # embedding_weights part 0. 1179 np.array([[3., 3.3], [2., 2.2], [1., 1.1]]), 1180 # embedding_weights part 1. 1181 np.array([[0.5, 0.4], [0.42, 0.42], [0.42, 0.42]]) 1182 ] 1183 }, 1184 sess) 1185 1186 def testErrorConditions(self): 1187 x = variable_scope.get_variable( 1188 "x", 1189 shape=[4, 1], 1190 initializer=ones(), 1191 partitioner=lambda shape, dtype: [2, 1]) 1192 1193 # List of PartitionedVariable is invalid type when warm-starting with vocab. 1194 self.assertRaises(TypeError, ws_util._warm_start_var_with_vocab, [x], 1195 "/tmp", 5, "/tmp", "/tmp") 1196 1197 # Unused variable names raises ValueError. 1198 with ops.Graph().as_default(): 1199 with self.cached_session() as sess: 1200 x = variable_scope.get_variable( 1201 "x", 1202 shape=[4, 1], 1203 initializer=ones(), 1204 partitioner=lambda shape, dtype: [2, 1]) 1205 self._write_checkpoint(sess) 1206 1207 self.assertRaises( 1208 ValueError, 1209 ws_util.warm_start, 1210 self.get_temp_dir(), 1211 var_name_to_vocab_info={"y": ws_util.VocabInfo("", 1, 0, "")}) 1212 self.assertRaises( 1213 ValueError, 1214 ws_util.warm_start, 1215 self.get_temp_dir(), 1216 var_name_to_prev_var_name={"y": "y2"}) 1217 1218 1219if __name__ == "__main__": 1220 test.main() 1221