1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functional tests for the ops to generate and execute vocab remapping.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import os 21 22import numpy as np 23 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import errors 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import test_util 29from tensorflow.python.ops import gen_checkpoint_ops 30from tensorflow.python.ops import partitioned_variables 31from tensorflow.python.ops import variable_scope 32from tensorflow.python.ops import variables 33from tensorflow.python.platform import flags 34from tensorflow.python.platform import test 35from tensorflow.python.training import saver 36 37FLAGS = flags.FLAGS 38 39 40class GenerateVocabRemappingTest(test.TestCase): 41 """Tests for the generate_vocab_remapping() method.""" 42 43 def setUp(self): 44 self.new_vocab_file = os.path.join(self.get_temp_dir(), 45 'keyword_shifted.txt') 46 with open(self.new_vocab_file, 'w') as f: 47 f.write('\n'.join(['MISSING', 'knitting', 'eminem']) + '\n') 48 self.old_vocab_file = os.path.join(self.get_temp_dir(), 49 'keyword.txt') 50 with open(self.old_vocab_file, 'w') as f: 51 f.write('\n'.join(['knitting', 'eminem', 'MISSING']) + '\n') 52 53 @test_util.run_deprecated_v1 54 def test_generate_remapping_with_no_vocab_changes(self): 55 """Tests where vocab does not change at all.""" 56 remapping, num_present = gen_checkpoint_ops.generate_vocab_remapping( 57 new_vocab_file=self.old_vocab_file, 58 old_vocab_file=self.old_vocab_file, 59 num_new_vocab=3, 60 new_vocab_offset=0) 61 expected_remapping = range(0, 3) 62 expected_num_present = 3 63 with self.cached_session(): 64 self.assertAllEqual(expected_remapping, self.evaluate(remapping)) 65 self.assertAllEqual(expected_num_present, self.evaluate(num_present)) 66 67 def test_generate_remapping_with_shifted_vocab(self): 68 """Tests where vocab is the same, but shifted / ordered differently.""" 69 remapping, num_present = gen_checkpoint_ops.generate_vocab_remapping( 70 new_vocab_file=self.new_vocab_file, 71 old_vocab_file=self.old_vocab_file, 72 num_new_vocab=3, 73 new_vocab_offset=0) 74 expected_remapping = [2, 0, 1] 75 expected_num_present = 3 76 with self.cached_session(): 77 self.assertAllEqual(expected_remapping, self.evaluate(remapping)) 78 self.assertAllEqual(expected_num_present, self.evaluate(num_present)) 79 80 def test_generate_remapping_with_offset(self): 81 """Tests offset and num_new_vocab logic.""" 82 remapping, num_present = gen_checkpoint_ops.generate_vocab_remapping( 83 new_vocab_file=self.new_vocab_file, 84 old_vocab_file=self.old_vocab_file, 85 num_new_vocab=1, 86 new_vocab_offset=1) 87 expected_remapping = [0] 88 expected_num_present = 1 89 with self.cached_session(): 90 self.assertAllEqual(expected_remapping, self.evaluate(remapping)) 91 self.assertAllEqual(expected_num_present, self.evaluate(num_present)) 92 93 def test_generate_remapping_with_old_vocab_size(self): 94 """Tests where old_vocab_size is specified.""" 95 remapping, num_present = gen_checkpoint_ops.generate_vocab_remapping( 96 new_vocab_file=self.new_vocab_file, 97 old_vocab_file=self.old_vocab_file, 98 num_new_vocab=3, 99 new_vocab_offset=0, 100 # Old vocabulary becomes ['knitting', 'eminem']. 101 old_vocab_size=2) 102 expected_remapping = [-1, 0, 1] 103 expected_num_present = 2 104 with self.cached_session(): 105 self.assertAllEqual(expected_remapping, self.evaluate(remapping)) 106 self.assertAllEqual(expected_num_present, self.evaluate(num_present)) 107 108 109class LoadAndRemapMatrixTest(test.TestCase): 110 """Tests for the load_and_remap_matrix() op.""" 111 112 def setUp(self): 113 ops.reset_default_graph() 114 self.old_num_rows = 5 115 self.old_num_cols = 16 116 self.matrix_value = np.reshape( 117 range(0, self.old_num_rows * self.old_num_cols), (self.old_num_rows, 118 self.old_num_cols)) 119 with variable_scope.variable_scope('some_scope'): 120 matrix = variable_scope.get_variable( 121 'matrix', 122 dtype=dtypes.float32, 123 initializer=constant_op.constant( 124 self.matrix_value, dtype=dtypes.float32)) 125 self.old_tensor_name = 'some_scope/matrix' 126 127 save = saver.Saver([matrix]) 128 with self.cached_session() as sess: 129 self.evaluate(variables.global_variables_initializer()) 130 self.bundle_file = os.path.join(test.get_temp_dir(), 'bundle_checkpoint') 131 save.save(sess, self.bundle_file) 132 133 def test_load_and_remap_no_missing(self): 134 """Tests the op's load and remap where there are no missing entries.""" 135 136 # No column remapping, new weight matrix has second row, then first row. 137 row_remapping = [1, 0] 138 remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( 139 ckpt_path=[self.bundle_file], 140 old_tensor_name=self.old_tensor_name, 141 row_remapping=row_remapping, 142 col_remapping=[], 143 initializing_values=[], 144 num_rows=2, 145 num_cols=self.old_num_cols) 146 with self.cached_session(): 147 self.assertAllClose(self.matrix_value[row_remapping], 148 self.evaluate(remapped_matrix)) 149 150 # No row remapping, new weight matrix has third col, then first col. 151 row_remapping = list(range(self.old_num_rows)) 152 col_remapping = [2, 0] 153 remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( 154 ckpt_path=[self.bundle_file], 155 old_tensor_name=self.old_tensor_name, 156 row_remapping=row_remapping, 157 col_remapping=col_remapping, 158 initializing_values=[], 159 num_rows=len(row_remapping), 160 num_cols=len(col_remapping)) 161 with self.cached_session(): 162 self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping], 163 self.evaluate(remapped_matrix)) 164 165 # Both row and column remappings. 166 row_remapping = [1, 0, 4] 167 col_remapping = [1, 15] 168 remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( 169 ckpt_path=[self.bundle_file], 170 old_tensor_name=self.old_tensor_name, 171 row_remapping=row_remapping, 172 col_remapping=col_remapping, 173 initializing_values=[], 174 num_rows=len(row_remapping), 175 num_cols=len(col_remapping)) 176 with self.cached_session(): 177 self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping], 178 self.evaluate(remapped_matrix)) 179 180 def test_load_and_remap_with_init(self): 181 """Tests the op's load and remap where there are missing entries.""" 182 init_val = 42 183 remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( 184 ckpt_path=[self.bundle_file], 185 old_tensor_name=self.old_tensor_name, 186 row_remapping=[2, -1, 0], 187 col_remapping=[1, -1], 188 initializing_values=[init_val] * 4, 189 num_rows=3, 190 num_cols=2) 191 192 expected_remapped_matrix = np.reshape( 193 [33, init_val, init_val, init_val, 1, init_val], [3, 2]) 194 195 with self.cached_session(): 196 self.assertAllClose(expected_remapped_matrix, 197 self.evaluate(remapped_matrix)) 198 199 def test_load_and_remap_all_missing_rows(self): 200 """Tests when all the rows are missing and need to be initialized.""" 201 num_rows = 7 202 initializing_values = [42] * num_rows * self.old_num_cols 203 remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( 204 ckpt_path=[self.bundle_file], 205 old_tensor_name=self.old_tensor_name, 206 row_remapping=[-1] * num_rows, 207 col_remapping=[], 208 initializing_values=initializing_values, 209 num_rows=num_rows, 210 num_cols=self.old_num_cols) 211 with self.cached_session(): 212 self.assertAllClose( 213 np.reshape(initializing_values, (num_rows, self.old_num_cols)), 214 self.evaluate(remapped_matrix)) 215 216 def test_load_and_remap_all_missing_rows_and_cols(self): 217 """Tests when all the rows & cols are missing and need to be initialized.""" 218 num_rows = 7 219 num_cols = 4 220 initializing_values = [42] * num_rows * num_cols 221 remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( 222 ckpt_path=[self.bundle_file], 223 old_tensor_name=self.old_tensor_name, 224 row_remapping=[-1] * num_rows, 225 col_remapping=[-1] * num_cols, 226 initializing_values=initializing_values, 227 num_rows=num_rows, 228 num_cols=num_cols) 229 with self.cached_session(): 230 self.assertAllClose( 231 np.reshape(initializing_values, (num_rows, num_cols)), 232 self.evaluate(remapped_matrix)) 233 234 @test_util.run_deprecated_v1 235 def test_load_and_remap_invalid_remapping(self): 236 """Tests that errors are raised when an ID maps to multiple new IDs. 237 238 (This should usually not happen when using public APIs). 239 """ 240 invalid_remapping = [1, 0, 0, 0, 1, 2] 241 242 # Invalid row remapping. 243 remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( 244 ckpt_path=[self.bundle_file], 245 old_tensor_name=self.old_tensor_name, 246 row_remapping=invalid_remapping, 247 col_remapping=[], 248 initializing_values=[], 249 num_rows=len(invalid_remapping), 250 num_cols=self.old_num_cols) 251 with self.cached_session(), self.assertRaises(errors.UnimplementedError): 252 self.evaluate(remapped_matrix) 253 254 # Invalid column remapping. 255 remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( 256 ckpt_path=[self.bundle_file], 257 old_tensor_name=self.old_tensor_name, 258 row_remapping=list(range(self.old_num_rows)), 259 col_remapping=invalid_remapping, 260 initializing_values=[], 261 num_rows=self.old_num_rows, 262 num_cols=len(invalid_remapping)) 263 with self.cached_session(), self.assertRaises(errors.UnimplementedError): 264 self.evaluate(remapped_matrix) 265 266 @test_util.run_deprecated_v1 267 def test_load_and_remap_incorrect_initializing_values(self): 268 """Tests that errors are raised with incorrect number of init values.""" 269 remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( 270 ckpt_path=[self.bundle_file], 271 old_tensor_name=self.old_tensor_name, 272 row_remapping=[2, -1, 0], 273 col_remapping=[1, -1], 274 # Too few initializing values - there should be 4. For some reason, 275 # initializing_values must contain no element (instead of 3 or fewer) to 276 # ensure that a seg fault would reliably occur if the check raising the 277 # InvalidArgumentError were not present. 278 initializing_values=[], 279 num_rows=3, 280 num_cols=2) 281 with self.cached_session(), self.assertRaises(errors.InvalidArgumentError): 282 self.evaluate(remapped_matrix) 283 284 remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( 285 ckpt_path=[self.bundle_file], 286 old_tensor_name=self.old_tensor_name, 287 row_remapping=[2, -1, 0], 288 col_remapping=[1, -1], 289 # Too many initializing values - there should be 4. 290 initializing_values=[0] * 5, 291 num_rows=3, 292 num_cols=2) 293 with self.cached_session(), self.assertRaises(errors.InvalidArgumentError): 294 self.evaluate(remapped_matrix) 295 296 297class LoadAndRemapMatrixWithMaxRowsTest(test.TestCase): 298 """Tests for the load_and_remap_matrix() op. 299 300 (Specifically focused on the max_rows_in_memory arg and its effects on 301 TensorBundle's BundleReader and TensorSlice logic). 302 """ 303 304 def _test_loading_variable_with_max_rows(self, np_value, partitioner, 305 max_rows_in_memory): 306 """Helper function for various tests using max_rows_in_memory.""" 307 ops.reset_default_graph() 308 old_tensor_name = 'matrix_to_load_and_remap' 309 matrix = variable_scope.get_variable( 310 old_tensor_name, 311 dtype=dtypes.float32, 312 initializer=constant_op.constant(np_value, dtype=dtypes.float32), 313 partitioner=partitioner) 314 315 with self.cached_session() as sess: 316 ckpt_path = os.path.join(test.get_temp_dir(), 'temp_ckpt') 317 save = saver.Saver([matrix]) 318 self.evaluate(variables.global_variables_initializer()) 319 save.save(sess, ckpt_path) 320 num_rows, num_cols = np_value.shape 321 322 # Tests loading the entire tensor (except reversed). 323 remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( 324 ckpt_path=ckpt_path, 325 old_tensor_name=old_tensor_name, 326 # Simply reverses the rows of the matrix. 327 row_remapping=list(range(num_rows - 1, -1, -1)), 328 col_remapping=[], 329 initializing_values=[], 330 num_rows=num_rows, 331 num_cols=num_cols, 332 max_rows_in_memory=max_rows_in_memory) 333 self.assertAllClose(np_value[::-1], self.evaluate(remapped_matrix)) 334 335 # Tests loading the tensor (except for the first and last rows), with 336 # uninitialized values. Requires num_rows to be at least 3 since we're 337 # skipping the first and last rows. 338 self.assertGreater(num_rows, 2) 339 prefix_rows = 2 340 suffix_rows = 3 341 remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( 342 ckpt_path=ckpt_path, 343 old_tensor_name=old_tensor_name, 344 # Reverses the rows of the matrix, then prepends and appends 345 # uninitialized rows. 346 row_remapping=([-1] * prefix_rows + list(range(1, num_rows - 1)) + 347 [-1] * suffix_rows), 348 col_remapping=[], 349 initializing_values=[42] * (prefix_rows + suffix_rows) * num_cols, 350 num_rows=num_rows - 2 + prefix_rows + suffix_rows, 351 num_cols=num_cols, 352 max_rows_in_memory=max_rows_in_memory) 353 self.assertAllClose( 354 np.vstack([ 355 np.tile(42, [prefix_rows, num_cols]), np_value[1:-1], 356 np.tile(42, [suffix_rows, num_cols]) 357 ]), self.evaluate(remapped_matrix)) 358 359 # Tests when everything is taken from initializing_values. 360 new_rows = 7 361 initializing_values = [42] * new_rows * num_cols 362 remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( 363 ckpt_path=ckpt_path, 364 old_tensor_name=old_tensor_name, 365 # Nothing is loaded from the old tensor. 366 row_remapping=[-1] * new_rows, 367 col_remapping=[], 368 initializing_values=initializing_values, 369 num_rows=new_rows, 370 num_cols=num_cols, 371 max_rows_in_memory=max_rows_in_memory) 372 self.assertAllClose( 373 np.reshape(initializing_values, (new_rows, num_cols)), 374 self.evaluate(remapped_matrix)) 375 376 @test_util.run_deprecated_v1 377 def test_loading_rows_divisible_by_max_rows(self): 378 """Tests loading normal var when rows are evenly divisible by max_rows.""" 379 self._test_loading_variable_with_max_rows( 380 np_value=np.reshape(list(range(0, 36)), (9, 4)), 381 partitioner=None, 382 # 9 is evenly divisible by 3. 383 max_rows_in_memory=3) 384 385 @test_util.run_deprecated_v1 386 def test_loading_rows_not_divisible_by_max_rows(self): 387 """Tests loading normal var when rows aren't divisible by max_rows.""" 388 self._test_loading_variable_with_max_rows( 389 np_value=np.reshape(list(range(0, 36)), (9, 4)), 390 partitioner=None, 391 # 9 is not evenly divisible by 4. 392 max_rows_in_memory=4) 393 394 @test_util.run_deprecated_v1 395 def test_loading_rows_less_than_max_rows(self): 396 """Tests loading normal var as a single slice. 397 398 (When the specified max_rows_in_memory is larger than the number of rows) 399 """ 400 self._test_loading_variable_with_max_rows( 401 np_value=np.reshape(list(range(0, 36)), (9, 4)), 402 partitioner=None, 403 # 10 > 9. 404 max_rows_in_memory=10) 405 406 @test_util.run_deprecated_v1 407 def test_loading_no_max_rows(self): 408 """Tests loading normal var as a single slice with no valid max_rows.""" 409 self._test_loading_variable_with_max_rows( 410 np_value=np.reshape(list(range(0, 18)), (6, 3)), 411 partitioner=None, 412 max_rows_in_memory=-1) 413 414 @test_util.run_deprecated_v1 415 def test_loading_partitions_equals_max_rows(self): 416 """Tests loading partitioned var sliced on partition boundary.""" 417 self._test_loading_variable_with_max_rows( 418 np_value=np.reshape(list(range(0, 36)), (9, 4)), 419 partitioner=partitioned_variables.fixed_size_partitioner(3), 420 # With a tensor of shape [9, 3] and 3 partitions, each partition has 421 # exactly 3 rows. 422 max_rows_in_memory=3) 423 424 @test_util.run_deprecated_v1 425 def test_loading_partitions_greater_than_max_rows(self): 426 """Tests loading partitioned var with more slices than partitions.""" 427 self._test_loading_variable_with_max_rows( 428 np_value=np.reshape(list(range(0, 36)), (9, 4)), 429 partitioner=partitioned_variables.fixed_size_partitioner(3), 430 # Even though each partition has 3 rows, we'll only load the tensor one 431 # row at a time. 432 max_rows_in_memory=1) 433 434 @test_util.run_deprecated_v1 435 def test_loading_partitions_less_than_max_rows(self): 436 """Tests loading partitioned var as a single slice. 437 438 (When the specified max_rows_in_memory is larger than the number of rows) 439 """ 440 self._test_loading_variable_with_max_rows( 441 np_value=np.reshape(list(range(0, 36)), (9, 4)), 442 partitioner=partitioned_variables.fixed_size_partitioner(3), 443 max_rows_in_memory=10) 444 445 @test_util.run_deprecated_v1 446 def test_loading_partitions_no_max_rows(self): 447 """Tests loading partitioned var as single slice with no valid max_rows.""" 448 self._test_loading_variable_with_max_rows( 449 np_value=np.reshape(list(range(0, 36)), (9, 4)), 450 partitioner=partitioned_variables.fixed_size_partitioner(3), 451 max_rows_in_memory=-1) 452 453 454if __name__ == '__main__': 455 test.main() 456