1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import gc 22import glob 23import os 24import shutil 25import tempfile 26import time 27 28import numpy as np 29from six.moves import xrange # pylint: disable=redefined-builtin 30import tensorflow as tf 31 32# pylint: disable=g-bad-import-order 33import tensorflow.contrib.eager as tfe 34from tensorflow.contrib.eager.python.examples.spinn import data 35from third_party.examples.eager.spinn import spinn 36from tensorflow.contrib.summary import summary_test_util 37from tensorflow.python.eager import test 38from tensorflow.python.framework import test_util 39from tensorflow.python.training import checkpoint_management 40from tensorflow.python.training.tracking import util as trackable_utils 41# pylint: enable=g-bad-import-order 42 43 44def _generate_synthetic_snli_data_batch(sequence_length, 45 batch_size, 46 vocab_size): 47 """Generate a fake batch of SNLI data for testing.""" 48 with tf.device("cpu:0"): 49 labels = tf.random_uniform([batch_size], minval=1, maxval=4, dtype=tf.int64) 50 prem = tf.random_uniform( 51 (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64) 52 prem_trans = tf.constant(np.array( 53 [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, 54 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2, 55 3, 2, 2]] * batch_size, dtype=np.int64).T) 56 hypo = tf.random_uniform( 57 (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64) 58 hypo_trans = tf.constant(np.array( 59 [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, 60 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2, 61 3, 2, 2]] * batch_size, dtype=np.int64).T) 62 if tfe.num_gpus(): 63 labels = labels.gpu() 64 prem = prem.gpu() 65 prem_trans = prem_trans.gpu() 66 hypo = hypo.gpu() 67 hypo_trans = hypo_trans.gpu() 68 return labels, prem, prem_trans, hypo, hypo_trans 69 70 71def _test_spinn_config(d_embed, d_out, logdir=None, inference_sentences=None): 72 """Generate a config tuple for testing. 73 74 Args: 75 d_embed: Embedding dimensions. 76 d_out: Model output dimensions. 77 logdir: Optional logdir. 78 inference_sentences: A 2-tuple of strings representing the sentences (with 79 binary parsing result), e.g., 80 ("( ( The dog ) ( ( is running ) . ) )", "( ( The dog ) ( moves . ) )"). 81 82 Returns: 83 A config tuple. 84 """ 85 config_tuple = collections.namedtuple( 86 "Config", ["d_hidden", "d_proj", "d_tracker", "predict", 87 "embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp", 88 "d_out", "projection", "lr", "batch_size", "epochs", 89 "force_cpu", "logdir", "log_every", "dev_every", "save_every", 90 "lr_decay_every", "lr_decay_by", "inference_premise", 91 "inference_hypothesis"]) 92 93 inference_premise = inference_sentences[0] if inference_sentences else None 94 inference_hypothesis = inference_sentences[1] if inference_sentences else None 95 return config_tuple( 96 d_hidden=d_embed, 97 d_proj=d_embed * 2, 98 d_tracker=8, 99 predict=False, 100 embed_dropout=0.1, 101 mlp_dropout=0.1, 102 n_mlp_layers=2, 103 d_mlp=32, 104 d_out=d_out, 105 projection=True, 106 lr=2e-2, 107 batch_size=2, 108 epochs=20, 109 force_cpu=False, 110 logdir=logdir, 111 log_every=1, 112 dev_every=2, 113 save_every=2, 114 lr_decay_every=1, 115 lr_decay_by=0.75, 116 inference_premise=inference_premise, 117 inference_hypothesis=inference_hypothesis) 118 119 120class SpinnTest(test_util.TensorFlowTestCase): 121 122 def setUp(self): 123 super(SpinnTest, self).setUp() 124 self._test_device = "gpu:0" if tfe.num_gpus() else "cpu:0" 125 self._temp_data_dir = tempfile.mkdtemp() 126 127 def tearDown(self): 128 shutil.rmtree(self._temp_data_dir) 129 super(SpinnTest, self).tearDown() 130 131 def testBundle(self): 132 with tf.device(self._test_device): 133 lstm_iter = [np.array([[0, 1], [2, 3]], dtype=np.float32), 134 np.array([[0, -1], [-2, -3]], dtype=np.float32), 135 np.array([[0, 2], [4, 6]], dtype=np.float32), 136 np.array([[0, -2], [-4, -6]], dtype=np.float32)] 137 out = spinn._bundle(lstm_iter) 138 139 self.assertEqual(2, len(out)) 140 self.assertEqual(tf.float32, out[0].dtype) 141 self.assertEqual(tf.float32, out[1].dtype) 142 self.assertAllEqual(np.array([[0, 2, 0, -2, 0, 4, 0, -4]]).T, 143 out[0].numpy()) 144 self.assertAllEqual(np.array([[1, 3, -1, -3, 2, 6, -2, -6]]).T, 145 out[1].numpy()) 146 147 def testUnbunbdle(self): 148 with tf.device(self._test_device): 149 state = [np.array([[0, 1, 2], [3, 4, 5]], dtype=np.float32), 150 np.array([[0, -1, -2], [-3, -4, -5]], dtype=np.float32)] 151 out = spinn._unbundle(state) 152 153 self.assertEqual(2, len(out)) 154 self.assertEqual(tf.float32, out[0].dtype) 155 self.assertEqual(tf.float32, out[1].dtype) 156 self.assertAllEqual(np.array([[0, 1, 2, 0, -1, -2]]), 157 out[0].numpy()) 158 self.assertAllEqual(np.array([[3, 4, 5, -3, -4, -5]]), 159 out[1].numpy()) 160 161 def testReducer(self): 162 with tf.device(self._test_device): 163 batch_size = 3 164 size = 10 165 tracker_size = 8 166 reducer = spinn.Reducer(size, tracker_size=tracker_size) 167 168 left_in = [] 169 right_in = [] 170 tracking = [] 171 for _ in range(batch_size): 172 left_in.append(tf.random_normal((1, size * 2))) 173 right_in.append(tf.random_normal((1, size * 2))) 174 tracking.append(tf.random_normal((1, tracker_size * 2))) 175 176 out = reducer(left_in, right_in, tracking=tracking) 177 self.assertEqual(batch_size, len(out)) 178 self.assertEqual(tf.float32, out[0].dtype) 179 self.assertEqual((1, size * 2), out[0].shape) 180 181 def testReduceTreeLSTM(self): 182 with tf.device(self._test_device): 183 size = 10 184 tracker_size = 8 185 reducer = spinn.Reducer(size, tracker_size=tracker_size) 186 187 lstm_in = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 188 [0, -1, -2, -3, -4, -5, -6, -7, -8, -9]], 189 dtype=np.float32) 190 c1 = np.array([[0, 1], [2, 3]], dtype=np.float32) 191 c2 = np.array([[0, -1], [-2, -3]], dtype=np.float32) 192 193 h, c = reducer._tree_lstm(c1, c2, lstm_in) 194 self.assertEqual(tf.float32, h.dtype) 195 self.assertEqual(tf.float32, c.dtype) 196 self.assertEqual((2, 2), h.shape) 197 self.assertEqual((2, 2), c.shape) 198 199 def testTracker(self): 200 with tf.device(self._test_device): 201 batch_size = 2 202 size = 10 203 tracker_size = 8 204 buffer_length = 18 205 stack_size = 3 206 207 tracker = spinn.Tracker(tracker_size, False) 208 tracker.reset_state() 209 210 # Create dummy inputs for testing. 211 bufs = [] 212 buf = [] 213 for _ in range(buffer_length): 214 buf.append(tf.random_normal((batch_size, size * 2))) 215 bufs.append(buf) 216 self.assertEqual(1, len(bufs)) 217 self.assertEqual(buffer_length, len(bufs[0])) 218 self.assertEqual((batch_size, size * 2), bufs[0][0].shape) 219 220 stacks = [] 221 stack = [] 222 for _ in range(stack_size): 223 stack.append(tf.random_normal((batch_size, size * 2))) 224 stacks.append(stack) 225 self.assertEqual(1, len(stacks)) 226 self.assertEqual(3, len(stacks[0])) 227 self.assertEqual((batch_size, size * 2), stacks[0][0].shape) 228 229 for _ in range(2): 230 out1, out2 = tracker(bufs, stacks) 231 self.assertIsNone(out2) 232 self.assertEqual(batch_size, len(out1)) 233 self.assertEqual(tf.float32, out1[0].dtype) 234 self.assertEqual((1, tracker_size * 2), out1[0].shape) 235 236 self.assertEqual(tf.float32, tracker.state.c.dtype) 237 self.assertEqual((batch_size, tracker_size), tracker.state.c.shape) 238 self.assertEqual(tf.float32, tracker.state.h.dtype) 239 self.assertEqual((batch_size, tracker_size), tracker.state.h.shape) 240 241 def testSPINN(self): 242 with tf.device(self._test_device): 243 embedding_dims = 10 244 d_tracker = 8 245 sequence_length = 15 246 num_transitions = 27 247 248 config_tuple = collections.namedtuple( 249 "Config", ["d_hidden", "d_proj", "d_tracker", "predict"]) 250 config = config_tuple( 251 embedding_dims, embedding_dims * 2, d_tracker, False) 252 s = spinn.SPINN(config) 253 254 # Create some fake data. 255 buffers = tf.random_normal((sequence_length, 1, config.d_proj)) 256 transitions = tf.constant( 257 [[3], [3], [2], [3], [3], [3], [2], [2], [2], [3], [3], [3], 258 [2], [3], [3], [2], [2], [3], [3], [3], [2], [2], [2], [2], 259 [3], [2], [2]], dtype=tf.int64) 260 self.assertEqual(tf.int64, transitions.dtype) 261 self.assertEqual((num_transitions, 1), transitions.shape) 262 263 out = s(buffers, transitions, training=True) 264 self.assertEqual(tf.float32, out.dtype) 265 self.assertEqual((1, embedding_dims), out.shape) 266 267 def testSNLIClassifierAndTrainer(self): 268 with tf.device(self._test_device): 269 vocab_size = 40 270 batch_size = 2 271 d_embed = 10 272 sequence_length = 15 273 d_out = 4 274 275 config = _test_spinn_config(d_embed, d_out) 276 277 # Create fake embedding matrix. 278 embed = tf.random_normal((vocab_size, d_embed)) 279 280 model = spinn.SNLIClassifier(config, embed) 281 trainer = spinn.SNLIClassifierTrainer(model, config.lr) 282 283 (labels, prem, prem_trans, hypo, 284 hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length, 285 batch_size, 286 vocab_size) 287 288 # Invoke model under non-training mode. 289 logits = model(prem, prem_trans, hypo, hypo_trans, training=False) 290 self.assertEqual(tf.float32, logits.dtype) 291 self.assertEqual((batch_size, d_out), logits.shape) 292 293 # Invoke model under training model. 294 logits = model(prem, prem_trans, hypo, hypo_trans, training=True) 295 self.assertEqual(tf.float32, logits.dtype) 296 self.assertEqual((batch_size, d_out), logits.shape) 297 298 # Calculate loss. 299 loss1 = trainer.loss(labels, logits) 300 self.assertEqual(tf.float32, loss1.dtype) 301 self.assertEqual((), loss1.shape) 302 303 loss2, logits = trainer.train_batch( 304 labels, prem, prem_trans, hypo, hypo_trans) 305 self.assertEqual(tf.float32, loss2.dtype) 306 self.assertEqual((), loss2.shape) 307 self.assertEqual(tf.float32, logits.dtype) 308 self.assertEqual((batch_size, d_out), logits.shape) 309 # Training on the batch should have led to a change in the loss value. 310 self.assertNotEqual(loss1.numpy(), loss2.numpy()) 311 312 def _create_test_data(self, snli_1_0_dir): 313 fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") 314 os.makedirs(snli_1_0_dir) 315 316 # Four sentences in total. 317 with open(fake_train_file, "wt") as f: 318 f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" 319 "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" 320 "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") 321 f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t" 322 "DummySentence1Parse\tDummySentence2Parse\t" 323 "Foo bar.\tfoo baz.\t" 324 "4705552913.jpg#2\t4705552913.jpg#2r1n\t" 325 "neutral\tentailment\tneutral\tneutral\tneutral\n") 326 f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t" 327 "DummySentence1Parse\tDummySentence2Parse\t" 328 "Foo bar.\tfoo baz.\t" 329 "4705552913.jpg#2\t4705552913.jpg#2r1n\t" 330 "neutral\tentailment\tneutral\tneutral\tneutral\n") 331 f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t" 332 "DummySentence1Parse\tDummySentence2Parse\t" 333 "Foo bar.\tfoo baz.\t" 334 "4705552913.jpg#2\t4705552913.jpg#2r1n\t" 335 "neutral\tentailment\tneutral\tneutral\tneutral\n") 336 f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t" 337 "DummySentence1Parse\tDummySentence2Parse\t" 338 "Foo bar.\tfoo baz.\t" 339 "4705552913.jpg#2\t4705552913.jpg#2r1n\t" 340 "neutral\tentailment\tneutral\tneutral\tneutral\n") 341 342 glove_dir = os.path.join(self._temp_data_dir, "glove") 343 os.makedirs(glove_dir) 344 glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") 345 346 words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] 347 with open(glove_file, "wt") as f: 348 for i, word in enumerate(words): 349 f.write("%s " % word) 350 for j in range(data.WORD_VECTOR_LEN): 351 f.write("%.5f" % (i * 0.1)) 352 if j < data.WORD_VECTOR_LEN - 1: 353 f.write(" ") 354 else: 355 f.write("\n") 356 357 return fake_train_file 358 359 def testInferSpinnWorks(self): 360 """Test inference with the spinn model.""" 361 snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") 362 self._create_test_data(snli_1_0_dir) 363 364 vocab = data.load_vocabulary(self._temp_data_dir) 365 word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) 366 367 config = _test_spinn_config( 368 data.WORD_VECTOR_LEN, 4, 369 logdir=os.path.join(self._temp_data_dir, "logdir"), 370 inference_sentences=("( foo ( bar . ) )", "( bar ( foo . ) )")) 371 logits = spinn.train_or_infer_spinn( 372 embed, word2index, None, None, None, config) 373 self.assertEqual(tf.float32, logits.dtype) 374 self.assertEqual((3,), logits.shape) 375 376 def testInferSpinnThrowsErrorIfOnlyOneSentenceIsSpecified(self): 377 snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") 378 self._create_test_data(snli_1_0_dir) 379 380 vocab = data.load_vocabulary(self._temp_data_dir) 381 word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) 382 383 config = _test_spinn_config( 384 data.WORD_VECTOR_LEN, 4, 385 logdir=os.path.join(self._temp_data_dir, "logdir"), 386 inference_sentences=("( foo ( bar . ) )", None)) 387 with self.assertRaises(ValueError): 388 spinn.train_or_infer_spinn(embed, word2index, None, None, None, config) 389 390 def testTrainSpinn(self): 391 """Test with fake toy SNLI data and GloVe vectors.""" 392 393 # 1. Create and load a fake SNLI data file and a fake GloVe embedding file. 394 snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") 395 fake_train_file = self._create_test_data(snli_1_0_dir) 396 397 vocab = data.load_vocabulary(self._temp_data_dir) 398 word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) 399 400 train_data = data.SnliData(fake_train_file, word2index) 401 dev_data = data.SnliData(fake_train_file, word2index) 402 test_data = data.SnliData(fake_train_file, word2index) 403 404 # 2. Create a fake config. 405 config = _test_spinn_config( 406 data.WORD_VECTOR_LEN, 4, 407 logdir=os.path.join(self._temp_data_dir, "logdir")) 408 409 # 3. Test training of a SPINN model. 410 trainer = spinn.train_or_infer_spinn( 411 embed, word2index, train_data, dev_data, test_data, config) 412 413 # 4. Load train loss values from the summary files and verify that they 414 # decrease with training. 415 summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0] 416 events = summary_test_util.events_from_file(summary_file) 417 train_losses = [event.summary.value[0].simple_value for event in events 418 if event.summary.value 419 and event.summary.value[0].tag == "train/loss"] 420 self.assertEqual(config.epochs, len(train_losses)) 421 422 # 5. Verify that checkpoints exist and contains all the expected variables. 423 self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) 424 object_graph = trackable_utils.object_metadata( 425 checkpoint_management.latest_checkpoint(config.logdir)) 426 ckpt_variable_names = set() 427 for node in object_graph.nodes: 428 for attribute in node.attributes: 429 ckpt_variable_names.add(attribute.full_name) 430 self.assertIn("global_step", ckpt_variable_names) 431 for v in trainer.variables: 432 variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name 433 self.assertIn(variable_name, ckpt_variable_names) 434 435 436class EagerSpinnSNLIClassifierBenchmark(test.Benchmark): 437 438 def benchmarkEagerSpinnSNLIClassifier(self): 439 test_device = "gpu:0" if tfe.num_gpus() else "cpu:0" 440 with tf.device(test_device): 441 burn_in_iterations = 2 442 benchmark_iterations = 10 443 444 vocab_size = 1000 445 batch_size = 128 446 sequence_length = 15 447 d_embed = 200 448 d_out = 4 449 450 embed = tf.random_normal((vocab_size, d_embed)) 451 452 config = _test_spinn_config(d_embed, d_out) 453 model = spinn.SNLIClassifier(config, embed) 454 trainer = spinn.SNLIClassifierTrainer(model, config.lr) 455 456 (labels, prem, prem_trans, hypo, 457 hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length, 458 batch_size, 459 vocab_size) 460 461 for _ in range(burn_in_iterations): 462 trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans) 463 464 gc.collect() 465 start_time = time.time() 466 for _ in xrange(benchmark_iterations): 467 trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans) 468 wall_time = time.time() - start_time 469 # Named "examples"_per_sec to conform with other benchmarks. 470 extras = {"examples_per_sec": benchmark_iterations / wall_time} 471 self.report_benchmark( 472 name="Eager_SPINN_SNLIClassifier_Benchmark", 473 iters=benchmark_iterations, 474 wall_time=wall_time, 475 extras=extras) 476 477 478if __name__ == "__main__": 479 test.main() 480