# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for the ops to generate and execute vocab remapping."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_checkpoint_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.training import saver

FLAGS = flags.FLAGS


class GenerateVocabRemappingTest(test.TestCase):
41  """Tests for the generate_vocab_remapping() method."""
42
43  def setUp(self):
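    # The new vocab is the old vocab shifted by one slot: 'MISSING' moves from
    # the end to the front, pushing 'knitting' and 'eminem' down one index.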
    self.new_vocab_file = os.path.join(self.get_temp_dir(),
                                       'keyword_shifted.txt')
    with open(self.new_vocab_file, 'w') as f:
      f.write('\n'.join(['MISSING', 'knitting', 'eminem']) + '\n')
    self.old_vocab_file = os.path.join(self.get_temp_dir(),
                                       'keyword.txt')
    with open(self.old_vocab_file, 'w') as f:
      f.write('\n'.join(['knitting', 'eminem', 'MISSING']) + '\n')

  @test_util.run_deprecated_v1
  def test_generate_remapping_with_no_vocab_changes(self):
    """Tests where vocab does not change at all."""
    remapping, num_present = gen_checkpoint_ops.generate_vocab_remapping(
        new_vocab_file=self.old_vocab_file,
        old_vocab_file=self.old_vocab_file,
        num_new_vocab=3,
        new_vocab_offset=0)
    expected_remapping = range(0, 3)
    expected_num_present = 3
    with self.cached_session():
      self.assertAllEqual(expected_remapping, self.evaluate(remapping))
      self.assertAllEqual(expected_num_present, self.evaluate(num_present))

  def test_generate_remapping_with_shifted_vocab(self):
    """Tests where vocab is the same, but shifted / ordered differently."""
    remapping, num_present = gen_checkpoint_ops.generate_vocab_remapping(
        new_vocab_file=self.new_vocab_file,
        old_vocab_file=self.old_vocab_file,
        num_new_vocab=3,
        new_vocab_offset=0)
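    # Each new-vocab position maps to the index of the same token in the old
    # vocab: 'MISSING' -> 2, 'knitting' -> 0, 'eminem' -> 1.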
    expected_remapping = [2, 0, 1]
    expected_num_present = 3
    with self.cached_session():
      self.assertAllEqual(expected_remapping, self.evaluate(remapping))
      self.assertAllEqual(expected_num_present, self.evaluate(num_present))

  def test_generate_remapping_with_offset(self):
    """Tests offset and num_new_vocab logic."""
    remapping, num_present = gen_checkpoint_ops.generate_vocab_remapping(
        new_vocab_file=self.new_vocab_file,
        old_vocab_file=self.old_vocab_file,
        num_new_vocab=1,
        new_vocab_offset=1)
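    # With new_vocab_offset=1 and num_new_vocab=1, only new-vocab entry 1
    # ('knitting') is remapped, and it sits at index 0 in the old vocab.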
    expected_remapping = [0]
    expected_num_present = 1
    with self.cached_session():
      self.assertAllEqual(expected_remapping, self.evaluate(remapping))
      self.assertAllEqual(expected_num_present, self.evaluate(num_present))

  def test_generate_remapping_with_old_vocab_size(self):
    """Tests where old_vocab_size is specified."""
    remapping, num_present = gen_checkpoint_ops.generate_vocab_remapping(
        new_vocab_file=self.new_vocab_file,
        old_vocab_file=self.old_vocab_file,
        num_new_vocab=3,
        new_vocab_offset=0,
        # Old vocabulary becomes ['knitting', 'eminem'].
        old_vocab_size=2)
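    # 'MISSING' no longer appears in the truncated old vocab, so it maps to -1
    # and only two entries are present.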
    expected_remapping = [-1, 0, 1]
    expected_num_present = 2
    with self.cached_session():
      self.assertAllEqual(expected_remapping, self.evaluate(remapping))
      self.assertAllEqual(expected_num_present, self.evaluate(num_present))


class LoadAndRemapMatrixTest(test.TestCase):
  """Tests for the load_and_remap_matrix() op."""

  def setUp(self):
    ops.reset_default_graph()
    self.old_num_rows = 5
    self.old_num_cols = 16
    self.matrix_value = np.reshape(
        range(0, self.old_num_rows * self.old_num_cols), (self.old_num_rows,
                                                          self.old_num_cols))
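    # Entry (i, j) of the checkpointed 5x16 matrix is 16 * i + j, which makes
    # the remapped values asserted below easy to verify by hand.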
    with variable_scope.variable_scope('some_scope'):
      matrix = variable_scope.get_variable(
          'matrix',
          dtype=dtypes.float32,
          initializer=constant_op.constant(
              self.matrix_value, dtype=dtypes.float32))
      self.old_tensor_name = 'some_scope/matrix'

    save = saver.Saver([matrix])
    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      self.bundle_file = os.path.join(test.get_temp_dir(), 'bundle_checkpoint')
      save.save(sess, self.bundle_file)

  def test_load_and_remap_no_missing(self):
    """Tests the op's load and remap where there are no missing entries."""

    # No column remapping, new weight matrix has second row, then first row.
    row_remapping = [1, 0]
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=row_remapping,
        col_remapping=[],
        initializing_values=[],
        num_rows=2,
        num_cols=self.old_num_cols)
    with self.cached_session():
      self.assertAllClose(self.matrix_value[row_remapping],
                          self.evaluate(remapped_matrix))

    # No row remapping, new weight matrix has third col, then first col.
    row_remapping = list(range(self.old_num_rows))
    col_remapping = [2, 0]
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=row_remapping,
        col_remapping=col_remapping,
        initializing_values=[],
        num_rows=len(row_remapping),
        num_cols=len(col_remapping))
    with self.cached_session():
      self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
                          self.evaluate(remapped_matrix))

    # Both row and column remappings.
    row_remapping = [1, 0, 4]
    col_remapping = [1, 15]
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=row_remapping,
        col_remapping=col_remapping,
        initializing_values=[],
        num_rows=len(row_remapping),
        num_cols=len(col_remapping))
    with self.cached_session():
      self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
                          self.evaluate(remapped_matrix))

  def test_load_and_remap_with_init(self):
    """Tests the op's load and remap where there are missing entries."""
    init_val = 42
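    # One initializing value is required for every output cell whose row or
    # column maps to -1; here that is 4 of the 3 x 2 = 6 cells.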
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[2, -1, 0],
        col_remapping=[1, -1],
        initializing_values=[init_val] * 4,
        num_rows=3,
        num_cols=2)

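    # Loaded entries: old (row 2, col 1) = 2 * 16 + 1 = 33 and
    # old (row 0, col 1) = 1; everything else comes from initializing_values.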
    expected_remapped_matrix = np.reshape(
        [33, init_val, init_val, init_val, 1, init_val], [3, 2])

    with self.cached_session():
      self.assertAllClose(expected_remapped_matrix,
                          self.evaluate(remapped_matrix))

  def test_load_and_remap_all_missing_rows(self):
    """Tests when all the rows are missing and need to be initialized."""
    num_rows = 7
    initializing_values = [42] * num_rows * self.old_num_cols
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[-1] * num_rows,
        col_remapping=[],
        initializing_values=initializing_values,
        num_rows=num_rows,
        num_cols=self.old_num_cols)
    with self.cached_session():
      self.assertAllClose(
          np.reshape(initializing_values, (num_rows, self.old_num_cols)),
          self.evaluate(remapped_matrix))

  def test_load_and_remap_all_missing_rows_and_cols(self):
    """Tests when all the rows & cols are missing and need to be initialized."""
    num_rows = 7
    num_cols = 4
    initializing_values = [42] * num_rows * num_cols
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[-1] * num_rows,
        col_remapping=[-1] * num_cols,
        initializing_values=initializing_values,
        num_rows=num_rows,
        num_cols=num_cols)
    with self.cached_session():
      self.assertAllClose(
          np.reshape(initializing_values, (num_rows, num_cols)),
          self.evaluate(remapped_matrix))

  @test_util.run_deprecated_v1
  def test_load_and_remap_invalid_remapping(self):
    """Tests that errors are raised when an ID maps to multiple new IDs.

    (This should usually not happen when using public APIs).
    """
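    # Old IDs 0 and 1 each appear more than once, i.e. each would map to
    # multiple new IDs, which the op does not support.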
    invalid_remapping = [1, 0, 0, 0, 1, 2]

    # Invalid row remapping.
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=invalid_remapping,
        col_remapping=[],
        initializing_values=[],
        num_rows=len(invalid_remapping),
        num_cols=self.old_num_cols)
    with self.cached_session(), self.assertRaises(errors.UnimplementedError):
      self.evaluate(remapped_matrix)

    # Invalid column remapping.
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=list(range(self.old_num_rows)),
        col_remapping=invalid_remapping,
        initializing_values=[],
        num_rows=self.old_num_rows,
        num_cols=len(invalid_remapping))
    with self.cached_session(), self.assertRaises(errors.UnimplementedError):
      self.evaluate(remapped_matrix)

  @test_util.run_deprecated_v1
  def test_load_and_remap_incorrect_initializing_values(self):
    """Tests that errors are raised with incorrect number of init values."""
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[2, -1, 0],
        col_remapping=[1, -1],
        # Too few initializing values - there should be 4. For some reason,
        # initializing_values must be empty (rather than containing 3 or fewer
        # values) to ensure that a segfault would reliably occur if the check
        # raising the InvalidArgumentError were not present.
        initializing_values=[],
        num_rows=3,
        num_cols=2)
    with self.cached_session(), self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(remapped_matrix)

    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[2, -1, 0],
        col_remapping=[1, -1],
        # Too many initializing values - there should be 4.
        initializing_values=[0] * 5,
        num_rows=3,
        num_cols=2)
    with self.cached_session(), self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(remapped_matrix)


class LoadAndRemapMatrixWithMaxRowsTest(test.TestCase):
  """Tests for the load_and_remap_matrix() op.

  (Specifically focused on the max_rows_in_memory arg and its effects on
  TensorBundle's BundleReader and TensorSlice logic).
  """

  def _test_loading_variable_with_max_rows(self, np_value, partitioner,
                                           max_rows_in_memory):
    """Helper function for various tests using max_rows_in_memory."""
    ops.reset_default_graph()
    old_tensor_name = 'matrix_to_load_and_remap'
    matrix = variable_scope.get_variable(
        old_tensor_name,
        dtype=dtypes.float32,
        initializer=constant_op.constant(np_value, dtype=dtypes.float32),
        partitioner=partitioner)

    with self.cached_session() as sess:
      ckpt_path = os.path.join(test.get_temp_dir(), 'temp_ckpt')
      save = saver.Saver([matrix])
      self.evaluate(variables.global_variables_initializer())
      save.save(sess, ckpt_path)
      num_rows, num_cols = np_value.shape

      # Tests loading the entire tensor (except reversed).
      remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
          ckpt_path=ckpt_path,
          old_tensor_name=old_tensor_name,
          # Simply reverses the rows of the matrix.
          row_remapping=list(range(num_rows - 1, -1, -1)),
          col_remapping=[],
          initializing_values=[],
          num_rows=num_rows,
          num_cols=num_cols,
          max_rows_in_memory=max_rows_in_memory)
      self.assertAllClose(np_value[::-1], self.evaluate(remapped_matrix))

      # Tests loading the tensor (except for the first and last rows), with
      # uninitialized values. Requires num_rows to be at least 3 since we're
      # skipping the first and last rows.
      self.assertGreater(num_rows, 2)
      prefix_rows = 2
      suffix_rows = 3
      remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
          ckpt_path=ckpt_path,
          old_tensor_name=old_tensor_name,
          # Keeps rows 1 through num_rows - 2 of the matrix in order, then
          # prepends and appends uninitialized rows.
          row_remapping=([-1] * prefix_rows + list(range(1, num_rows - 1)) +
                         [-1] * suffix_rows),
          col_remapping=[],
          initializing_values=[42] * (prefix_rows + suffix_rows) * num_cols,
          num_rows=num_rows - 2 + prefix_rows + suffix_rows,
          num_cols=num_cols,
          max_rows_in_memory=max_rows_in_memory)
      self.assertAllClose(
          np.vstack([
              np.tile(42, [prefix_rows, num_cols]), np_value[1:-1],
              np.tile(42, [suffix_rows, num_cols])
          ]), self.evaluate(remapped_matrix))

      # Tests when everything is taken from initializing_values.
      new_rows = 7
      initializing_values = [42] * new_rows * num_cols
      remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
          ckpt_path=ckpt_path,
          old_tensor_name=old_tensor_name,
          # Nothing is loaded from the old tensor.
          row_remapping=[-1] * new_rows,
          col_remapping=[],
          initializing_values=initializing_values,
          num_rows=new_rows,
          num_cols=num_cols,
          max_rows_in_memory=max_rows_in_memory)
      self.assertAllClose(
          np.reshape(initializing_values, (new_rows, num_cols)),
          self.evaluate(remapped_matrix))

  @test_util.run_deprecated_v1
  def test_loading_rows_divisible_by_max_rows(self):
    """Tests loading normal var when rows are evenly divisible by max_rows."""
    self._test_loading_variable_with_max_rows(
        np_value=np.reshape(list(range(0, 36)), (9, 4)),
        partitioner=None,
        # 9 is evenly divisible by 3.
        max_rows_in_memory=3)

  @test_util.run_deprecated_v1
  def test_loading_rows_not_divisible_by_max_rows(self):
    """Tests loading normal var when rows aren't divisible by max_rows."""
    self._test_loading_variable_with_max_rows(
        np_value=np.reshape(list(range(0, 36)), (9, 4)),
        partitioner=None,
        # 9 is not evenly divisible by 4.
        max_rows_in_memory=4)

  @test_util.run_deprecated_v1
  def test_loading_rows_less_than_max_rows(self):
    """Tests loading normal var as a single slice.

    (When the specified max_rows_in_memory is larger than the number of rows)
    """
    self._test_loading_variable_with_max_rows(
        np_value=np.reshape(list(range(0, 36)), (9, 4)),
        partitioner=None,
        # 10 > 9.
        max_rows_in_memory=10)

  @test_util.run_deprecated_v1
  def test_loading_no_max_rows(self):
    """Tests loading normal var as a single slice with no valid max_rows."""
    self._test_loading_variable_with_max_rows(
        np_value=np.reshape(list(range(0, 18)), (6, 3)),
        partitioner=None,
        max_rows_in_memory=-1)

  @test_util.run_deprecated_v1
  def test_loading_partitions_equals_max_rows(self):
    """Tests loading partitioned var sliced on partition boundary."""
    self._test_loading_variable_with_max_rows(
        np_value=np.reshape(list(range(0, 36)), (9, 4)),
        partitioner=partitioned_variables.fixed_size_partitioner(3),
        # With a tensor of shape [9, 4] and 3 partitions, each partition has
        # exactly 3 rows.
        max_rows_in_memory=3)

  @test_util.run_deprecated_v1
  def test_loading_partitions_greater_than_max_rows(self):
    """Tests loading partitioned var with more slices than partitions."""
    self._test_loading_variable_with_max_rows(
        np_value=np.reshape(list(range(0, 36)), (9, 4)),
        partitioner=partitioned_variables.fixed_size_partitioner(3),
        # Even though each partition has 3 rows, we'll only load the tensor one
        # row at a time.
        max_rows_in_memory=1)

  @test_util.run_deprecated_v1
  def test_loading_partitions_less_than_max_rows(self):
    """Tests loading partitioned var as a single slice.

    (When the specified max_rows_in_memory is larger than the number of rows)
    """
    self._test_loading_variable_with_max_rows(
        np_value=np.reshape(list(range(0, 36)), (9, 4)),
        partitioner=partitioned_variables.fixed_size_partitioner(3),
        max_rows_in_memory=10)

  @test_util.run_deprecated_v1
  def test_loading_partitions_no_max_rows(self):
    """Tests loading partitioned var as single slice with no valid max_rows."""
    self._test_loading_variable_with_max_rows(
        np_value=np.reshape(list(range(0, 36)), (9, 4)),
        partitioner=partitioned_variables.fixed_size_partitioner(3),
        max_rows_in_memory=-1)


if __name__ == '__main__':
  test.main()