• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for warm_starting_util."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import numpy as np
23import six
24
25from tensorflow.python.feature_column import feature_column_lib as fc
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import init_ops
30from tensorflow.python.ops import variable_scope
31from tensorflow.python.ops import variables
32from tensorflow.python.platform import test
33from tensorflow.python.training import checkpoint_utils
34from tensorflow.python.training import saver as saver_lib
35from tensorflow.python.training import warm_starting_util as ws_util
36
# Short aliases for the initializer factories used throughout these tests.
ones = init_ops.ones_initializer
norms = init_ops.truncated_normal_initializer
rand = init_ops.random_uniform_initializer
zeros = init_ops.zeros_initializer
41
42
43class WarmStartingUtilTest(test.TestCase):
44
45  def _write_vocab(self, string_values, file_name):
46    vocab_file = os.path.join(self.get_temp_dir(), file_name)
47    with open(vocab_file, "w") as f:
48      f.write("\n".join(string_values))
49    return vocab_file
50
51  def _write_checkpoint(self, sess):
52    self.evaluate(variables.global_variables_initializer())
53    saver = saver_lib.Saver()
54    ckpt_prefix = os.path.join(self.get_temp_dir(), "model")
55    saver.save(sess, ckpt_prefix, global_step=0)
56
57  def _create_prev_run_var(self,
58                           var_name,
59                           shape=None,
60                           initializer=None,
61                           partitioner=None):
62    with ops.Graph().as_default() as g:
63      with self.session(graph=g) as sess:
64        var = variable_scope.get_variable(
65            var_name,
66            shape=shape,
67            initializer=initializer,
68            partitioner=partitioner)
69        self._write_checkpoint(sess)
70        if partitioner:
71          self.assertTrue(isinstance(var, variables.PartitionedVariable))
72          var = var._get_variable_list()
73        return var, self.evaluate(var)
74
75  def _create_prev_run_vars(self,
76                            var_names,
77                            shapes,
78                            initializers):
79    with ops.Graph().as_default() as g:
80      with self.session(graph=g) as sess:
81        all_vars = []
82        for var_name, shape, initializer in zip(var_names, shapes,
83                                                initializers):
84          all_vars.append(variable_scope.get_variable(
85              var_name,
86              shape=shape,
87              initializer=initializer))
88        self._write_checkpoint(sess)
89        return [self.evaluate(var) for var in all_vars]
90
91  def _create_dummy_inputs(self):
92    return {
93        "sc_int": array_ops.sparse_placeholder(dtypes.int32),
94        "sc_hash": array_ops.sparse_placeholder(dtypes.string),
95        "sc_keys": array_ops.sparse_placeholder(dtypes.string),
96        "sc_vocab": array_ops.sparse_placeholder(dtypes.string),
97        "real": array_ops.placeholder(dtypes.float32)
98    }
99
100  def _create_linear_model(self, feature_cols, partitioner):
101    cols_to_vars = {}
102    with variable_scope.variable_scope("", partitioner=partitioner):
103      # Create the variables.
104      fc.linear_model(
105          features=self._create_dummy_inputs(),
106          feature_columns=feature_cols,
107          units=1,
108          cols_to_vars=cols_to_vars)
109    # Return a dictionary mapping each column to its variable.
110    return cols_to_vars
111
112  def _assert_cols_to_vars(self, cols_to_vars, cols_to_expected_values, sess):
113    for col, expected_values in six.iteritems(cols_to_expected_values):
114      for i, var in enumerate(cols_to_vars[col]):
115        self.assertAllClose(expected_values[i], var.eval(sess))
116
  def testWarmStartVar(self):
    """Warm-starts an unpartitioned var from an unpartitioned prior var."""
    _, prev_val = self._create_prev_run_var(
        "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
        prev_tensor_name, var = ws_util._get_var_info(fruit_weights)
        checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
                                              {prev_tensor_name: var})
        self.evaluate(variables.global_variables_initializer())
        # Initialization should pick up the checkpointed values, not zeros.
        self.assertAllClose(prev_val, fruit_weights.eval(sess))
130
  def testWarmStartVarPrevVarPartitioned(self):
    """Warm-starts an unpartitioned var from a partitioned prior var."""
    _, weights = self._create_prev_run_var(
        "fruit_weights",
        shape=[4, 1],
        initializer=[[0.5], [1.], [1.5], [2.]],
        partitioner=lambda shape, dtype: [2, 1])
    # The prior value is the two partitions stitched back together.
    prev_val = np.concatenate([weights[0], weights[1]], axis=0)

    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
        prev_tensor_name, var = ws_util._get_var_info(fruit_weights)
        checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
                                              {prev_tensor_name: var})
        self.evaluate(variables.global_variables_initializer())
        self.assertAllClose(prev_val, fruit_weights.eval(sess))
148
  def testWarmStartVarCurrentVarPartitioned(self):
    """Warm-starts a partitioned var from an unpartitioned prior var."""
    _, prev_val = self._create_prev_run_var(
        "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights",
            shape=[4, 1],
            initializer=[[0.], [0.], [0.], [0.]],
            partitioner=lambda shape, dtype: [2, 1])
        self.assertTrue(
            isinstance(fruit_weights, variables.PartitionedVariable))
        prev_tensor_name, var = ws_util._get_var_info(fruit_weights)
        checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
                                              {prev_tensor_name: var})
        self.evaluate(variables.global_variables_initializer())
        # Stitch the new partitions together and compare to the prior value.
        fruit_weights = fruit_weights._get_variable_list()
        new_val = np.concatenate(
            [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
        self.assertAllClose(prev_val, new_val)
170
  def testWarmStartVarBothVarsPartitioned(self):
    """Warm-starts a partitioned var from a partitioned, renamed prior var."""
    _, weights = self._create_prev_run_var(
        "old_scope/fruit_weights",
        shape=[4, 1],
        initializer=[[0.5], [1.], [1.5], [2.]],
        partitioner=lambda shape, dtype: [2, 1])
    prev_val = np.concatenate([weights[0], weights[1]], axis=0)
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "new_scope/fruit_weights",
            shape=[4, 1],
            initializer=[[0.], [0.], [0.], [0.]],
            partitioner=lambda shape, dtype: [2, 1])
        self.assertTrue(
            isinstance(fruit_weights, variables.PartitionedVariable))
        # Explicit prev_tensor_name maps the new name onto the old scope.
        prev_tensor_name, var = ws_util._get_var_info(
            fruit_weights, prev_tensor_name="old_scope/fruit_weights")
        checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
                                              {prev_tensor_name: var})
        self.evaluate(variables.global_variables_initializer())
        fruit_weights = fruit_weights._get_variable_list()
        new_val = np.concatenate(
            [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
        self.assertAllClose(prev_val, new_val)
197
  def testWarmStartVarWithVocab(self):
    """Warm-starts via vocab remapping: rows follow the new vocab order."""
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    self._create_prev_run_var(
        "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights", initializer=[[0.], [0.], [0.], [0.], [0.]])
        ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
                                           self.get_temp_dir(), prev_vocab_path)
        self.evaluate(variables.global_variables_initializer())
        # Old weights permuted into new-vocab order; unseen entry stays 0.
        self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
                            fruit_weights.eval(sess))
217
  def testWarmStartVarWithColumnVocab(self):
    """Warm-starts with vocab remapping along axis=1 (output columns)."""
    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
    self._create_prev_run_var(
        "fruit_output_layer",
        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
                                       "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_output_layer = variable_scope.get_variable(
            "fruit_output_layer",
            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
                         [0., 0., 0.]])
        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
                                           current_vocab_size=3,
                                           prev_ckpt=self.get_temp_dir(),
                                           prev_vocab_path=prev_vocab_path,
                                           axis=1)
        self.evaluate(variables.global_variables_initializer())
        # Columns permuted to new-vocab order; the new "banana" column is 0.
        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
                             [2.3, 2., 0.]], fruit_output_layer.eval(sess))
242
  def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
    """Warm-starts with the old vocab truncated via previous_vocab_size."""
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    self._create_prev_run_var(
        "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights", initializer=[[0.], [0.], [0.], [0.], [0.]])
        ws_util._warm_start_var_with_vocab(
            fruit_weights,
            new_vocab_path,
            5,
            self.get_temp_dir(),
            prev_vocab_path,
            previous_vocab_size=2)
        self.evaluate(variables.global_variables_initializer())
        # Old vocabulary limited to ['apple', 'banana'].
        self.assertAllClose([[0.], [0.], [1.], [0.5], [0.]],
                            fruit_weights.eval(sess))
268
  def testWarmStartVarWithVocabPrevVarPartitioned(self):
    """Vocab-remapped warm-start when the prior variable was partitioned."""
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    self._create_prev_run_var(
        "fruit_weights",
        shape=[4, 1],
        initializer=[[0.5], [1.], [1.5], [2.]],
        partitioner=lambda shape, dtype: [2, 1])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights", initializer=[[0.], [0.], [0.], [0.], [0.]])
        ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
                                           self.get_temp_dir(), prev_vocab_path)
        self.evaluate(variables.global_variables_initializer())
        self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
                            fruit_weights.eval(sess))
291
  def testWarmStartVarWithColumnVocabPrevVarPartitioned(self):
    """Axis=1 vocab warm-start when the prior variable was partitioned."""
    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
    self._create_prev_run_var(
        "fruit_output_layer",
        shape=[4, 2],
        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
        partitioner=lambda shape, dtype: [2, 1])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
                                       "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_output_layer = variable_scope.get_variable(
            "fruit_output_layer",
            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
                         [0., 0., 0.]])
        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
                                           current_vocab_size=3,
                                           prev_ckpt=self.get_temp_dir(),
                                           prev_vocab_path=prev_vocab_path,
                                           axis=1)
        self.evaluate(variables.global_variables_initializer())
        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
                             [2.3, 2., 0.]], fruit_output_layer.eval(sess))
318
  def testWarmStartVarWithVocabCurrentVarPartitioned(self):
    """Vocab warm-start into a partitioned current var with one OOV bucket."""
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    self._create_prev_run_var(
        "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights",
            shape=[6, 1],
            initializer=[[0.], [0.], [0.], [0.], [0.], [0.]],
            partitioner=lambda shape, dtype: [2, 1])
        ws_util._warm_start_var_with_vocab(
            fruit_weights,
            new_vocab_path,
            5,
            self.get_temp_dir(),
            prev_vocab_path,
            current_oov_buckets=1)
        self.evaluate(variables.global_variables_initializer())
        self.assertTrue(
            isinstance(fruit_weights, variables.PartitionedVariable))
        fruit_weights_vars = fruit_weights._get_variable_list()
        # Rows 0-2 land in partition 0; rows 3-5 (incl. OOV row) in partition 1.
        self.assertAllClose([[2.], [1.5], [1.]],
                            fruit_weights_vars[0].eval(sess))
        self.assertAllClose([[0.5], [0.], [0.]],
                            fruit_weights_vars[1].eval(sess))
351
  def testWarmStartVarWithColumnVocabCurrentVarPartitioned(self):
    """Axis=1 vocab warm-start into a partitioned current variable."""
    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
    self._create_prev_run_var(
        "fruit_output_layer",
        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
                                       "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_output_layer = variable_scope.get_variable(
            "fruit_output_layer",
            shape=[4, 3],
            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
                         [0., 0., 0.]],
            partitioner=lambda shape, dtype: [2, 1])
        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
                                           current_vocab_size=3,
                                           prev_ckpt=self.get_temp_dir(),
                                           prev_vocab_path=prev_vocab_path,
                                           axis=1)
        self.evaluate(variables.global_variables_initializer())
        self.assertTrue(
            isinstance(fruit_output_layer, variables.PartitionedVariable))
        fruit_output_layer_vars = fruit_output_layer._get_variable_list()
        # Partitioning is along rows (axis 0); columns follow the new vocab.
        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
                            fruit_output_layer_vars[0].eval(sess))
        self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
                            fruit_output_layer_vars[1].eval(sess))
383
  def testWarmStartVarWithVocabBothVarsPartitioned(self):
    """Vocab warm-start with both prior and current variables partitioned."""
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    self._create_prev_run_var(
        "fruit_weights",
        shape=[4, 1],
        initializer=[[0.5], [1.], [1.5], [2.]],
        partitioner=lambda shape, dtype: [2, 1])

    # New vocab with elements in reverse order and two new elements.
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry",
         "blueberry"], "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights",
            shape=[6, 1],
            initializer=[[0.], [0.], [0.], [0.], [0.], [0.]],
            partitioner=lambda shape, dtype: [2, 1])
        ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 6,
                                           self.get_temp_dir(), prev_vocab_path)
        self.evaluate(variables.global_variables_initializer())
        self.assertTrue(
            isinstance(fruit_weights, variables.PartitionedVariable))
        fruit_weights_vars = fruit_weights._get_variable_list()
        # Unseen entries ("raspberry", "blueberry") remain at zero.
        self.assertAllClose([[2.], [1.5], [1.]],
                            fruit_weights_vars[0].eval(sess))
        self.assertAllClose([[0.5], [0.], [0.]],
                            fruit_weights_vars[1].eval(sess))
415
  def testWarmStartVarWithColumnVocabBothVarsPartitioned(self):
    """Axis=1 vocab warm-start with both prior and current vars partitioned."""
    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
    self._create_prev_run_var(
        "fruit_output_layer",
        shape=[4, 2],
        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
        partitioner=lambda shape, dtype: [2, 1])

    # New vocab with elements in reverse order and one new element.
    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
                                       "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        fruit_output_layer = variable_scope.get_variable(
            "fruit_output_layer",
            shape=[4, 3],
            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
                         [0., 0., 0.]],
            partitioner=lambda shape, dtype: [2, 1])
        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
                                           current_vocab_size=3,
                                           prev_ckpt=self.get_temp_dir(),
                                           prev_vocab_path=prev_vocab_path,
                                           axis=1)
        self.evaluate(variables.global_variables_initializer())
        self.assertTrue(
            isinstance(fruit_output_layer, variables.PartitionedVariable))
        fruit_output_layer_vars = fruit_output_layer._get_variable_list()
        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
                            fruit_output_layer_vars[0].eval(sess))
        self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
                            fruit_output_layer_vars[1].eval(sess))
449
  def testWarmStart_ListOfVariables(self):
    """warm_start accepts a list of Variable objects as vars_to_warm_start."""
    # Save checkpoint from which to warm-start.
    _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
                                                initializer=ones())
    # Verify we initialized the values correctly.
    self.assertAllEqual(np.ones([10, 1]), prev_int_val)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        # Initialize with zeros.
        var = variable_scope.get_variable(
            "v1",
            shape=[10, 1],
            initializer=zeros())
        ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=[var])
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started (init overridden to ones).
        self.assertAllEqual(var.eval(), prev_int_val)
469
  def testWarmStart_ListOfStrings(self):
    """warm_start accepts a list of variable-name strings."""
    # Save checkpoint from which to warm-start.
    _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
                                                initializer=ones())
    # Verify we initialized the values correctly.
    self.assertAllEqual(np.ones([10, 1]), prev_int_val)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        # Initialize with zeros.
        var = variable_scope.get_variable(
            "v1",
            shape=[10, 1],
            initializer=zeros())
        ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=["v1"])
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started (init overridden to ones).
        self.assertAllEqual(var.eval(), prev_int_val)
489
  def testWarmStart_ListOfRegexes(self):
    """warm_start accepts regexes; non-matching vars keep their initializer."""
    # Save checkpoint from which to warm-start.
    [prev_v1_val, prev_v1_momentum_val,
     prev_v2_val, _] = self._create_prev_run_vars(
         var_names=["v1", "v1/Momentum", "v2", "v2/Momentum"],
         shapes=[[10, 1]] * 4,
         initializers=[ones()] * 4)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        # Initialize with zeros.
        v1 = variable_scope.get_variable(
            "v1",
            shape=[10, 1],
            initializer=zeros())
        v1_momentum = variable_scope.get_variable(
            "v1/Momentum",
            shape=[10, 1],
            initializer=zeros())
        v2 = variable_scope.get_variable(
            "v2",
            shape=[10, 1],
            initializer=zeros())
        v2_momentum = variable_scope.get_variable(
            "v2/Momentum",
            shape=[10, 1],
            initializer=zeros())
        ws_util.warm_start(self.get_temp_dir(),
                           # This warm-starts both v1 and v1/Momentum, but only
                           # v2 (and not v2/Momentum).
                           vars_to_warm_start=["v1", "v2[^/]"])
        self.evaluate(variables.global_variables_initializer())
        # Verify the selection of weights were correctly warm-started (init
        # overridden to ones).
        self.assertAllEqual(v1.eval(), prev_v1_val)
        self.assertAllEqual(v1_momentum.eval(), prev_v1_momentum_val)
        self.assertAllEqual(v2.eval(), prev_v2_val)
        self.assertAllEqual(v2_momentum.eval(), np.zeros([10, 1]))
529
  def testWarmStart_SparseColumnIntegerized(self):
    """Warm-starts the weights of an identity (integerized) sparse column."""
    # Create feature column.
    sc_int = fc.categorical_column_with_identity("sc_int", num_buckets=10)

    # Save checkpoint from which to warm-start.
    _, prev_int_val = self._create_prev_run_var(
        "linear_model/sc_int/weights", shape=[10, 1], initializer=ones())
    # Verify we initialized the values correctly.
    self.assertAllEqual(np.ones([10, 1]), prev_int_val)

    partitioner = lambda shape, dtype: [1] * len(shape)
    # New graph, new session WITHOUT warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_int], partitioner)
        self.evaluate(variables.global_variables_initializer())
        # Without warm-starting, the weights should be initialized using default
        # initializer (which is init_ops.zeros_initializer).
        self._assert_cols_to_vars(cols_to_vars, {sc_int: [np.zeros([10, 1])]},
                                  sess)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_int], partitioner)
        ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=".*sc_int.*")
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started.
        self._assert_cols_to_vars(cols_to_vars, {sc_int: [prev_int_val]}, sess)
559
  def testWarmStart_SparseColumnHashed(self):
    """Warm-starts the weights of a hash-bucketed sparse column."""
    # Create feature column.
    sc_hash = fc.categorical_column_with_hash_bucket(
        "sc_hash", hash_bucket_size=15)

    # Save checkpoint from which to warm-start.
    _, prev_hash_val = self._create_prev_run_var(
        "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())

    partitioner = lambda shape, dtype: [1] * len(shape)
    # New graph, new session WITHOUT warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_hash], partitioner)
        self.evaluate(variables.global_variables_initializer())
        # Without warm-starting, the weights should be initialized using default
        # initializer (which is init_ops.zeros_initializer).
        self._assert_cols_to_vars(cols_to_vars, {sc_hash: [np.zeros([15, 1])]},
                                  sess)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_hash], partitioner)
        ws_util.warm_start(
            self.get_temp_dir(), vars_to_warm_start=".*sc_hash.*")
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started.
        self._assert_cols_to_vars(cols_to_vars, {sc_hash: [prev_hash_val]},
                                  sess)
590
  def testWarmStart_SparseColumnVocabulary(self):
    """Warm-starts a vocabulary sparse column when old and new vocab match."""
    # Create vocab for sparse column "sc_vocab".
    vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                   "vocab")
    # Create feature column.
    sc_vocab = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)

    # Save checkpoint from which to warm-start.
    _, prev_vocab_val = self._create_prev_run_var(
        "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())

    partitioner = lambda shape, dtype: [1] * len(shape)
    # New graph, new session WITHOUT warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
        self.evaluate(variables.global_variables_initializer())
        # Without warm-starting, the weights should be initialized using default
        # initializer (which is init_ops.zeros_initializer).
        self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([4, 1])]},
                                  sess)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
        # Since old vocab is not explicitly set in WarmStartSettings, the old
        # vocab is assumed to be same as new vocab.
        ws_util.warm_start(
            self.get_temp_dir(), vars_to_warm_start=".*sc_vocab.*")
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started.
        self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
                                  sess)
626
  def testWarmStart_ExplicitCheckpointFile(self):
    """warm_start works with an explicit checkpoint file prefix, not a dir."""
    # Create vocab for sparse column "sc_vocab".
    vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                   "vocab")
    # Create feature column.
    sc_vocab = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)

    # Save checkpoint from which to warm-start.
    _, prev_vocab_val = self._create_prev_run_var(
        "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())

    partitioner = lambda shape, dtype: [1] * len(shape)
    # New graph, new session WITHOUT warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
        self.evaluate(variables.global_variables_initializer())
        # Without warm-starting, the weights should be initialized using default
        # initializer (which is init_ops.zeros_initializer).
        self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([4, 1])]},
                                  sess)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
        # Since old vocab is not explicitly set in WarmStartSettings, the old
        # vocab is assumed to be same as new vocab.
        ws_util.warm_start(
            # Explicitly provide the file prefix instead of just the dir.
            # "model-0" matches the prefix/global_step in _write_checkpoint.
            os.path.join(self.get_temp_dir(), "model-0"),
            vars_to_warm_start=".*sc_vocab.*")
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started.
        self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
                                  sess)
664
665  def testWarmStart_SparseColumnVocabularyConstrainedVocabSizes(self):
666    # Create old vocabulary, and use a size smaller than the total number of
667    # entries.
668    old_vocab_path = self._write_vocab(["apple", "guava", "banana"],
669                                       "old_vocab")
670    old_vocab_size = 2  # ['apple', 'guava']
671
672    # Create new vocab for sparse column "sc_vocab".
673    current_vocab_path = self._write_vocab(
674        ["apple", "banana", "guava", "orange"], "current_vocab")
675    # Create feature column.  Only use 2 of the actual entries, resulting in
676    # ['apple', 'banana'] for the new vocabulary.
677    sc_vocab = fc.categorical_column_with_vocabulary_file(
678        "sc_vocab", vocabulary_file=current_vocab_path, vocabulary_size=2)
679
680    # Save checkpoint from which to warm-start.
681    self._create_prev_run_var(
682        "linear_model/sc_vocab/weights", shape=[2, 1], initializer=ones())
683
684    partitioner = lambda shape, dtype: [1] * len(shape)
685    # New graph, new session WITHOUT warm-starting.
686    with ops.Graph().as_default() as g:
687      with self.session(graph=g) as sess:
688        cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
689        self.evaluate(variables.global_variables_initializer())
690        # Without warm-starting, the weights should be initialized using default
691        # initializer (which is init_ops.zeros_initializer).
692        self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([2, 1])]},
693                                  sess)
694
695    # New graph, new session with warm-starting.
696    with ops.Graph().as_default() as g:
697      with self.session(graph=g) as sess:
698        cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
699        vocab_info = ws_util.VocabInfo(
700            new_vocab=sc_vocab.vocabulary_file,
701            new_vocab_size=sc_vocab.vocabulary_size,
702            num_oov_buckets=sc_vocab.num_oov_buckets,
703            old_vocab=old_vocab_path,
704            old_vocab_size=old_vocab_size)
705        ws_util.warm_start(
706            ckpt_to_initialize_from=self.get_temp_dir(),
707            vars_to_warm_start=".*sc_vocab.*",
708            var_name_to_vocab_info={
709                "linear_model/sc_vocab/weights": vocab_info
710            })
711        self.evaluate(variables.global_variables_initializer())
712        # Verify weights were correctly warm-started.  'banana' isn't in the
713        # first two entries of the old vocabulary, so it's newly initialized.
714        self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [[[1], [0]]]}, sess)
715
716  def testWarmStart_BucketizedColumn(self):
717    # Create feature column.
718    real = fc.numeric_column("real")
719    real_bucket = fc.bucketized_column(real, boundaries=[0., 1., 2., 3.])
720
721    # Save checkpoint from which to warm-start.
722    _, prev_bucket_val = self._create_prev_run_var(
723        "linear_model/real_bucketized/weights",
724        shape=[5, 1],
725        initializer=norms())
726
727    partitioner = lambda shape, dtype: [1] * len(shape)
728    # New graph, new session WITHOUT warm-starting.
729    with ops.Graph().as_default() as g:
730      with self.session(graph=g) as sess:
731        cols_to_vars = self._create_linear_model([real_bucket], partitioner)
732        self.evaluate(variables.global_variables_initializer())
733        # Without warm-starting, the weights should be initialized using default
734        # initializer (which is init_ops.zeros_initializer).
735        self._assert_cols_to_vars(cols_to_vars,
736                                  {real_bucket: [np.zeros([5, 1])]}, sess)
737
738    # New graph, new session with warm-starting.
739    with ops.Graph().as_default() as g:
740      with self.session(graph=g) as sess:
741        cols_to_vars = self._create_linear_model([real_bucket], partitioner)
742        ws_util.warm_start(
743            self.get_temp_dir(), vars_to_warm_start=".*real_bucketized.*")
744        self.evaluate(variables.global_variables_initializer())
745        # Verify weights were correctly warm-started.
746        self._assert_cols_to_vars(cols_to_vars,
747                                  {real_bucket: [prev_bucket_val]}, sess)
748
  def testWarmStart_MultipleCols(self):
    """Warm-starts every weight variable of a multi-column linear model.

    Saves a checkpoint holding weights for six feature columns plus a bias,
    then verifies (a) a fresh model without warm-starting is all zeros, and
    (b) after ws_util.warm_start each variable equals its saved value.
    """
    # Create vocab for sparse column "sc_vocab".
    vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                   "vocab")

    # Create feature columns.
    sc_int = fc.categorical_column_with_identity("sc_int", num_buckets=10)
    sc_hash = fc.categorical_column_with_hash_bucket(
        "sc_hash", hash_bucket_size=15)
    sc_keys = fc.categorical_column_with_vocabulary_list(
        "sc_keys", vocabulary_list=["a", "b", "c", "e"])
    sc_vocab = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)
    real = fc.numeric_column("real")
    real_bucket = fc.bucketized_column(real, boundaries=[0., 1., 2., 3.])
    cross = fc.crossed_column([sc_keys, sc_vocab], hash_bucket_size=20)
    all_linear_cols = [sc_int, sc_hash, sc_keys, sc_vocab, real_bucket, cross]

    # Save checkpoint from which to warm-start.  Also create a bias variable,
    # so we can check that it's also warm-started.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        sc_int_weights = variable_scope.get_variable(
            "linear_model/sc_int/weights", shape=[10, 1], initializer=ones())
        sc_hash_weights = variable_scope.get_variable(
            "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
        sc_keys_weights = variable_scope.get_variable(
            "linear_model/sc_keys/weights", shape=[4, 1], initializer=rand())
        sc_vocab_weights = variable_scope.get_variable(
            "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())
        real_bucket_weights = variable_scope.get_variable(
            "linear_model/real_bucketized/weights",
            shape=[5, 1],
            initializer=norms())
        cross_weights = variable_scope.get_variable(
            "linear_model/sc_keys_X_sc_vocab/weights",
            shape=[20, 1],
            initializer=rand())
        bias = variable_scope.get_variable(
            "linear_model/bias_weights",
            shape=[1],
            initializer=rand())
        self._write_checkpoint(sess)
        # NOTE: the unpacking order below must match the sess.run list
        # element-for-element.
        (prev_int_val, prev_hash_val, prev_keys_val, prev_vocab_val,
         prev_bucket_val, prev_cross_val, prev_bias_val) = sess.run([
             sc_int_weights, sc_hash_weights, sc_keys_weights, sc_vocab_weights,
             real_bucket_weights, cross_weights, bias
         ])

    partitioner = lambda shape, dtype: [1] * len(shape)
    # New graph, new session WITHOUT warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model(all_linear_cols, partitioner)
        self.evaluate(variables.global_variables_initializer())
        # Without warm-starting, all weights should be initialized using default
        # initializer (which is init_ops.zeros_initializer).
        self._assert_cols_to_vars(cols_to_vars, {
            sc_int: [np.zeros([10, 1])],
            sc_hash: [np.zeros([15, 1])],
            sc_keys: [np.zeros([4, 1])],
            sc_vocab: [np.zeros([4, 1])],
            real_bucket: [np.zeros([5, 1])],
            cross: [np.zeros([20, 1])],
        }, sess)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model(all_linear_cols, partitioner)
        # The old vocab passed in VocabInfo is the same file as the column's
        # current vocab, so the vocab remapping is the identity.
        vocab_info = ws_util.VocabInfo(
            new_vocab=sc_vocab.vocabulary_file,
            new_vocab_size=sc_vocab.vocabulary_size,
            num_oov_buckets=sc_vocab.num_oov_buckets,
            old_vocab=vocab_path)
        ws_util.warm_start(
            self.get_temp_dir(),
            var_name_to_vocab_info={
                "linear_model/sc_vocab/weights": vocab_info
            })
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started.
        self._assert_cols_to_vars(cols_to_vars, {
            sc_int: [prev_int_val],
            sc_hash: [prev_hash_val],
            sc_keys: [prev_keys_val],
            sc_vocab: [prev_vocab_val],
            real_bucket: [prev_bucket_val],
            cross: [prev_cross_val],
            "bias": [prev_bias_val],
        }, sess)
840
841  def testWarmStartMoreSettings(self):
842    # Create old and new vocabs for sparse column "sc_vocab".
843    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
844                                        "old_vocab")
845    new_vocab_path = self._write_vocab(
846        ["orange", "guava", "banana", "apple", "raspberry",
847         "blueberry"], "new_vocab")
848    # Create feature columns.
849    sc_hash = fc.categorical_column_with_hash_bucket(
850        "sc_hash", hash_bucket_size=15)
851    sc_keys = fc.categorical_column_with_vocabulary_list(
852        "sc_keys", vocabulary_list=["a", "b", "c", "e"])
853    sc_vocab = fc.categorical_column_with_vocabulary_file(
854        "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
855    all_linear_cols = [sc_hash, sc_keys, sc_vocab]
856
857    # Save checkpoint from which to warm-start.
858    with ops.Graph().as_default() as g:
859      with self.session(graph=g) as sess:
860        variable_scope.get_variable(
861            "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
862        sc_keys_weights = variable_scope.get_variable(
863            "some_other_name", shape=[4, 1], initializer=rand())
864        variable_scope.get_variable(
865            "linear_model/sc_vocab/weights",
866            initializer=[[0.5], [1.], [2.], [3.]])
867        self._write_checkpoint(sess)
868        prev_keys_val = self.evaluate(sc_keys_weights)
869
870    def _partitioner(shape, dtype):  # pylint:disable=unused-argument
871      # Partition each var into 2 equal slices.
872      partitions = [1] * len(shape)
873      partitions[0] = min(2, shape.dims[0].value)
874      return partitions
875
876    # New graph, new session with warm-starting.
877    with ops.Graph().as_default() as g:
878      with self.session(graph=g) as sess:
879        cols_to_vars = self._create_linear_model(all_linear_cols, _partitioner)
880        vocab_info = ws_util.VocabInfo(
881            new_vocab=sc_vocab.vocabulary_file,
882            new_vocab_size=sc_vocab.vocabulary_size,
883            num_oov_buckets=sc_vocab.num_oov_buckets,
884            old_vocab=prev_vocab_path)
885        ws_util.warm_start(
886            self.get_temp_dir(),
887            vars_to_warm_start=".*(sc_keys|sc_vocab).*",
888            var_name_to_vocab_info={
889                ws_util._infer_var_name(cols_to_vars[sc_vocab]): vocab_info
890            },
891            var_name_to_prev_var_name={
892                ws_util._infer_var_name(cols_to_vars[sc_keys]):
893                    "some_other_name"
894            })
895        self.evaluate(variables.global_variables_initializer())
896        # Verify weights were correctly warm-started.  Var corresponding to
897        # sc_hash should not be warm-started.  Var corresponding to sc_vocab
898        # should be correctly warm-started after vocab remapping.
899        self._assert_cols_to_vars(cols_to_vars, {
900            sc_keys:
901                np.split(prev_keys_val, 2),
902            sc_hash: [np.zeros([8, 1]), np.zeros([7, 1])],
903            sc_vocab: [
904                np.array([[3.], [2.], [1.]]),
905                np.array([[0.5], [0.], [0.]])
906            ]
907        }, sess)
908
909  def testWarmStartMoreSettingsNoPartitioning(self):
910    # Create old and new vocabs for sparse column "sc_vocab".
911    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
912                                        "old_vocab")
913    new_vocab_path = self._write_vocab(
914        ["orange", "guava", "banana", "apple", "raspberry",
915         "blueberry"], "new_vocab")
916    # Create feature columns.
917    sc_hash = fc.categorical_column_with_hash_bucket(
918        "sc_hash", hash_bucket_size=15)
919    sc_keys = fc.categorical_column_with_vocabulary_list(
920        "sc_keys", vocabulary_list=["a", "b", "c", "e"])
921    sc_vocab = fc.categorical_column_with_vocabulary_file(
922        "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
923    all_linear_cols = [sc_hash, sc_keys, sc_vocab]
924
925    # Save checkpoint from which to warm-start.
926    with ops.Graph().as_default() as g:
927      with self.session(graph=g) as sess:
928        variable_scope.get_variable(
929            "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
930        sc_keys_weights = variable_scope.get_variable(
931            "some_other_name", shape=[4, 1], initializer=rand())
932        variable_scope.get_variable(
933            "linear_model/sc_vocab/weights",
934            initializer=[[0.5], [1.], [2.], [3.]])
935        self._write_checkpoint(sess)
936        prev_keys_val = self.evaluate(sc_keys_weights)
937
938    # New graph, new session with warm-starting.
939    with ops.Graph().as_default() as g:
940      with self.session(graph=g) as sess:
941        cols_to_vars = self._create_linear_model(all_linear_cols,
942                                                 partitioner=None)
943        vocab_info = ws_util.VocabInfo(
944            new_vocab=sc_vocab.vocabulary_file,
945            new_vocab_size=sc_vocab.vocabulary_size,
946            num_oov_buckets=sc_vocab.num_oov_buckets,
947            old_vocab=prev_vocab_path)
948        ws_util.warm_start(
949            self.get_temp_dir(),
950            vars_to_warm_start=".*(sc_keys|sc_vocab).*",
951            var_name_to_vocab_info={
952                ws_util._infer_var_name(cols_to_vars[sc_vocab]): vocab_info
953            },
954            var_name_to_prev_var_name={
955                ws_util._infer_var_name(cols_to_vars[sc_keys]):
956                    "some_other_name"
957            })
958        self.evaluate(variables.global_variables_initializer())
959        # Verify weights were correctly warm-started.  Var corresponding to
960        # sc_hash should not be warm-started.  Var corresponding to sc_vocab
961        # should be correctly warm-started after vocab remapping.
962        self._assert_cols_to_vars(cols_to_vars, {
963            sc_keys: [prev_keys_val],
964            sc_hash: [np.zeros([15, 1])],
965            sc_vocab: [np.array([[3.], [2.], [1.], [0.5], [0.], [0.]])]
966        }, sess)
967
  def testWarmStartVarsToWarmstartIsNone(self):
    """Checks that vars_to_warm_start=None restricts warm-starting.

    With vars_to_warm_start=None, only the variables listed in
    var_name_to_vocab_info are warm-started; the var_name_to_prev_var_name
    mapping is ignored, as asserted at the end of this test.
    """
    # Create old and new vocabs for sparse column "sc_vocab".
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry",
         "blueberry"], "new_vocab")
    # Create feature columns.
    sc_hash = fc.categorical_column_with_hash_bucket(
        "sc_hash", hash_bucket_size=15)
    sc_keys = fc.categorical_column_with_vocabulary_list(
        "sc_keys", vocabulary_list=["a", "b", "c", "e"])
    sc_vocab = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
    all_linear_cols = [sc_hash, sc_keys, sc_vocab]

    # Save checkpoint from which to warm-start.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        variable_scope.get_variable(
            "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
        variable_scope.get_variable(
            "some_other_name", shape=[4, 1], initializer=rand())
        variable_scope.get_variable(
            "linear_model/sc_vocab/weights",
            initializer=[[0.5], [1.], [2.], [3.]])
        self._write_checkpoint(sess)

    def _partitioner(shape, dtype):  # pylint:disable=unused-argument
      # Partition each var into 2 equal slices.
      partitions = [1] * len(shape)
      partitions[0] = min(2, shape.dims[0].value)
      return partitions

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = self._create_linear_model(all_linear_cols, _partitioner)
        vocab_info = ws_util.VocabInfo(
            new_vocab=sc_vocab.vocabulary_file,
            new_vocab_size=sc_vocab.vocabulary_size,
            num_oov_buckets=sc_vocab.num_oov_buckets,
            old_vocab=prev_vocab_path)
        ws_util.warm_start(
            self.get_temp_dir(),
            # The special value of None here will ensure that only the variable
            # specified in var_name_to_vocab_info (sc_vocab embedding) is
            # warm-started.
            vars_to_warm_start=None,
            var_name_to_vocab_info={
                ws_util._infer_var_name(cols_to_vars[sc_vocab]): vocab_info
            },
            # Even though this is provided, the None value for
            # vars_to_warm_start overrides the logic, and this will not be
            # warm-started.
            var_name_to_prev_var_name={
                ws_util._infer_var_name(cols_to_vars[sc_keys]):
                    "some_other_name"
            })
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started.  Var corresponding to
        # sc_vocab should be correctly warm-started after vocab remapping,
        # and neither of the other two should be warm-started..
        self._assert_cols_to_vars(cols_to_vars, {
            sc_keys: [np.zeros([2, 1]), np.zeros([2, 1])],
            sc_hash: [np.zeros([8, 1]), np.zeros([7, 1])],
            sc_vocab: [
                np.array([[3.], [2.], [1.]]),
                np.array([[0.5], [0.], [0.]])
            ]
        }, sess)
1039
1040  def testWarmStartEmbeddingColumn(self):
1041    # Create old and new vocabs for embedding column "sc_vocab".
1042    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
1043                                        "old_vocab")
1044    new_vocab_path = self._write_vocab(
1045        ["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
1046        "new_vocab")
1047
1048    # Save checkpoint from which to warm-start.
1049    with ops.Graph().as_default() as g:
1050      with self.session(graph=g) as sess:
1051        variable_scope.get_variable(
1052            "input_layer/sc_vocab_embedding/embedding_weights",
1053            initializer=[[0.5, 0.4], [1., 1.1], [2., 2.2], [3., 3.3]])
1054        self._write_checkpoint(sess)
1055
1056    def _partitioner(shape, dtype):  # pylint:disable=unused-argument
1057      # Partition each var into 2 equal slices.
1058      partitions = [1] * len(shape)
1059      partitions[0] = min(2, shape.dims[0].value)
1060      return partitions
1061
1062    # Create feature columns.
1063    sc_vocab = fc.categorical_column_with_vocabulary_file(
1064        "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
1065    emb_vocab_column = fc.embedding_column(
1066        categorical_column=sc_vocab,
1067        dimension=2)
1068    all_deep_cols = [emb_vocab_column]
1069    # New graph, new session with warm-starting.
1070    with ops.Graph().as_default() as g:
1071      with self.session(graph=g) as sess:
1072        cols_to_vars = {}
1073        with variable_scope.variable_scope("", partitioner=_partitioner):
1074          # Create the variables.
1075          fc.input_layer(
1076              features=self._create_dummy_inputs(),
1077              feature_columns=all_deep_cols,
1078              cols_to_vars=cols_to_vars)
1079        vocab_info = ws_util.VocabInfo(
1080            new_vocab=sc_vocab.vocabulary_file,
1081            new_vocab_size=sc_vocab.vocabulary_size,
1082            num_oov_buckets=sc_vocab.num_oov_buckets,
1083            old_vocab=prev_vocab_path,
1084            # Can't use constant_initializer with load_and_remap.  In practice,
1085            # use a truncated normal initializer.
1086            backup_initializer=init_ops.random_uniform_initializer(
1087                minval=0.42, maxval=0.42))
1088        ws_util.warm_start(
1089            self.get_temp_dir(),
1090            var_name_to_vocab_info={
1091                ws_util._infer_var_name(cols_to_vars[emb_vocab_column]):
1092                    vocab_info
1093            })
1094        self.evaluate(variables.global_variables_initializer())
1095        # Verify weights were correctly warm-started. Var corresponding to
1096        # emb_vocab_column should be correctly warm-started after vocab
1097        # remapping. Missing values are filled in with the EmbeddingColumn's
1098        # initializer.
1099        self._assert_cols_to_vars(
1100            cols_to_vars, {
1101                emb_vocab_column: [
1102                    np.array([[3., 3.3], [2., 2.2], [1., 1.1]]),
1103                    np.array([[0.5, 0.4], [0.42, 0.42], [0.42, 0.42]])
1104                ]
1105            }, sess)
1106
  def testWarmStartEmbeddingColumnLinearModel(self):
    """Warm-starts an embedding column used inside fc.linear_model.

    The checkpoint holds both the embedding weights (remapped via VocabInfo)
    and the linear weights of the embedding column; the final assertion
    checks the linear-weight partitions and embedding-weight partitions in
    the exact order listed in cols_to_vars.
    """
    # Create old and new vocabs for embedding column "sc_vocab".
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
        "new_vocab")

    # Save checkpoint from which to warm-start.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        variable_scope.get_variable(
            "linear_model/sc_vocab_embedding/embedding_weights",
            initializer=[[0.5, 0.4], [1., 1.1], [2., 2.2], [3., 3.3]])
        variable_scope.get_variable(
            "linear_model/sc_vocab_embedding/weights",
            initializer=[[0.69], [0.71]])
        self._write_checkpoint(sess)

    def _partitioner(shape, dtype):  # pylint:disable=unused-argument
      # Partition each var into 2 equal slices.
      partitions = [1] * len(shape)
      partitions[0] = min(2, shape.dims[0].value)
      return partitions

    # Create feature columns.
    sc_vocab = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
    emb_vocab = fc.embedding_column(
        categorical_column=sc_vocab,
        dimension=2)
    all_deep_cols = [emb_vocab]
    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cols_to_vars = {}
        with variable_scope.variable_scope("", partitioner=_partitioner):
          # Create the variables.
          fc.linear_model(
              features=self._create_dummy_inputs(),
              feature_columns=all_deep_cols,
              cols_to_vars=cols_to_vars)

        # Construct the vocab_info for the embedding weight.
        vocab_info = ws_util.VocabInfo(
            new_vocab=sc_vocab.vocabulary_file,
            new_vocab_size=sc_vocab.vocabulary_size,
            num_oov_buckets=sc_vocab.num_oov_buckets,
            old_vocab=prev_vocab_path,
            # Can't use constant_initializer with load_and_remap.  In practice,
            # use a truncated normal initializer.
            backup_initializer=init_ops.random_uniform_initializer(
                minval=0.42, maxval=0.42))
        ws_util.warm_start(
            self.get_temp_dir(),
            vars_to_warm_start=".*sc_vocab.*",
            var_name_to_vocab_info={
                "linear_model/sc_vocab_embedding/embedding_weights": vocab_info
            })
        self.evaluate(variables.global_variables_initializer())
        # Verify weights were correctly warm-started. Var corresponding to
        # emb_vocab should be correctly warm-started after vocab remapping.
        # Missing values are filled in with the EmbeddingColumn's initializer.
        self._assert_cols_to_vars(
            cols_to_vars,
            {
                emb_vocab: [
                    # linear weights part 0.
                    np.array([[0.69]]),
                    # linear weights part 1.
                    np.array([[0.71]]),
                    # embedding_weights part 0.
                    np.array([[3., 3.3], [2., 2.2], [1., 1.1]]),
                    # embedding_weights part 1.
                    np.array([[0.5, 0.4], [0.42, 0.42], [0.42, 0.42]])
                ]
            },
            sess)
1185
1186  def testErrorConditions(self):
1187    x = variable_scope.get_variable(
1188        "x",
1189        shape=[4, 1],
1190        initializer=ones(),
1191        partitioner=lambda shape, dtype: [2, 1])
1192
1193    # List of PartitionedVariable is invalid type when warm-starting with vocab.
1194    self.assertRaises(TypeError, ws_util._warm_start_var_with_vocab, [x],
1195                      "/tmp", 5, "/tmp", "/tmp")
1196
1197    # Unused variable names raises ValueError.
1198    with ops.Graph().as_default():
1199      with self.cached_session() as sess:
1200        x = variable_scope.get_variable(
1201            "x",
1202            shape=[4, 1],
1203            initializer=ones(),
1204            partitioner=lambda shape, dtype: [2, 1])
1205        self._write_checkpoint(sess)
1206
1207    self.assertRaises(
1208        ValueError,
1209        ws_util.warm_start,
1210        self.get_temp_dir(),
1211        var_name_to_vocab_info={"y": ws_util.VocabInfo("", 1, 0, "")})
1212    self.assertRaises(
1213        ValueError,
1214        ws_util.warm_start,
1215        self.get_temp_dir(),
1216        var_name_to_prev_var_name={"y": "y2"})
1217
1218
if __name__ == "__main__":
  # Run all test cases in this module via the TensorFlow test runner.
  test.main()
1221