# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
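"""Unit tests and benchmarks for the SPINN model under eager execution.

Covers the _bundle/_unbundle helpers, the Reducer, Tracker, and SPINN
building blocks, the SNLI classifier and trainer, and end-to-end training
and inference on synthetic SNLI data with fake GloVe vectors.
"""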

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import gc
import glob
import os
import shutil
import tempfile
import time

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

# pylint: disable=g-bad-import-order
import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.spinn import data
from third_party.examples.eager.spinn import spinn
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import util as trackable_utils
# pylint: enable=g-bad-import-order


def _generate_synthetic_snli_data_batch(sequence_length,
                                        batch_size,
                                        vocab_size):
  """Generate a fake batch of SNLI data for testing."""
  with tf.device("cpu:0"):
    labels = tf.random_uniform([batch_size], minval=1, maxval=4, dtype=tf.int64)
    prem = tf.random_uniform(
        (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64)
    # Fixed transition sequences for the synthetic parses. In this example's
    # encoding, 3 denotes a shift and 2 denotes a reduce.
    prem_trans = tf.constant(np.array(
        [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3,
          2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2,
          3, 2, 2]] * batch_size, dtype=np.int64).T)
    hypo = tf.random_uniform(
        (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64)
    hypo_trans = tf.constant(np.array(
        [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3,
          2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2,
          3, 2, 2]] * batch_size, dtype=np.int64).T)
  # Copy the tensors to GPU if one is available.
  if tfe.num_gpus():
    labels = labels.gpu()
    prem = prem.gpu()
    prem_trans = prem_trans.gpu()
    hypo = hypo.gpu()
    hypo_trans = hypo_trans.gpu()
  return labels, prem, prem_trans, hypo, hypo_trans


def _test_spinn_config(d_embed, d_out, logdir=None, inference_sentences=None):
  """Generate a config tuple for testing.

  Args:
    d_embed: Embedding dimensions.
    d_out: Model output dimensions.
    logdir: Optional logdir.
    inference_sentences: A 2-tuple of strings representing the sentences (with
      binary parse results), e.g.,
      ("( ( The dog ) ( ( is running ) . ) )", "( ( The dog ) ( moves . ) )").

  Returns:
    A config tuple.
  """
  config_tuple = collections.namedtuple(
      "Config", ["d_hidden", "d_proj", "d_tracker", "predict",
                 "embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp",
                 "d_out", "projection", "lr", "batch_size", "epochs",
                 "force_cpu", "logdir", "log_every", "dev_every", "save_every",
                 "lr_decay_every", "lr_decay_by", "inference_premise",
                 "inference_hypothesis"])

  inference_premise = inference_sentences[0] if inference_sentences else None
  inference_hypothesis = inference_sentences[1] if inference_sentences else None
  return config_tuple(
      d_hidden=d_embed,
      d_proj=d_embed * 2,
      d_tracker=8,
      predict=False,
      embed_dropout=0.1,
      mlp_dropout=0.1,
      n_mlp_layers=2,
      d_mlp=32,
      d_out=d_out,
      projection=True,
      lr=2e-2,
      batch_size=2,
      epochs=20,
      force_cpu=False,
      logdir=logdir,
      log_every=1,
      dev_every=2,
      save_every=2,
      lr_decay_every=1,
      lr_decay_by=0.75,
      inference_premise=inference_premise,
      inference_hypothesis=inference_hypothesis)


class SpinnTest(test_util.TensorFlowTestCase):
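  """Tests for the SPINN building blocks and the SNLI classifier/trainer."""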

  def setUp(self):
    super(SpinnTest, self).setUp()
    self._test_device = "gpu:0" if tfe.num_gpus() else "cpu:0"
    self._temp_data_dir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self._temp_data_dir)
    super(SpinnTest, self).tearDown()

  def testBundle(self):
    with tf.device(self._test_device):
      lstm_iter = [np.array([[0, 1], [2, 3]], dtype=np.float32),
                   np.array([[0, -1], [-2, -3]], dtype=np.float32),
                   np.array([[0, 2], [4, 6]], dtype=np.float32),
                   np.array([[0, -2], [-4, -6]], dtype=np.float32)]
      out = spinn._bundle(lstm_iter)

      self.assertEqual(2, len(out))
      self.assertEqual(tf.float32, out[0].dtype)
      self.assertEqual(tf.float32, out[1].dtype)
      self.assertAllEqual(np.array([[0, 2, 0, -2, 0, 4, 0, -4]]).T,
                          out[0].numpy())
      self.assertAllEqual(np.array([[1, 3, -1, -3, 2, 6, -2, -6]]).T,
                          out[1].numpy())

  def testUnbundle(self):
    with tf.device(self._test_device):
      state = [np.array([[0, 1, 2], [3, 4, 5]], dtype=np.float32),
               np.array([[0, -1, -2], [-3, -4, -5]], dtype=np.float32)]
      out = spinn._unbundle(state)

      self.assertEqual(2, len(out))
      self.assertEqual(tf.float32, out[0].dtype)
      self.assertEqual(tf.float32, out[1].dtype)
      self.assertAllEqual(np.array([[0, 1, 2, 0, -1, -2]]),
                          out[0].numpy())
      self.assertAllEqual(np.array([[3, 4, 5, -3, -4, -5]]),
                          out[1].numpy())

  def testReducer(self):
    with tf.device(self._test_device):
      batch_size = 3
      size = 10
      tracker_size = 8
      reducer = spinn.Reducer(size, tracker_size=tracker_size)

      left_in = []
      right_in = []
      tracking = []
      for _ in range(batch_size):
        left_in.append(tf.random_normal((1, size * 2)))
        right_in.append(tf.random_normal((1, size * 2)))
        tracking.append(tf.random_normal((1, tracker_size * 2)))

      out = reducer(left_in, right_in, tracking=tracking)
      self.assertEqual(batch_size, len(out))
      self.assertEqual(tf.float32, out[0].dtype)
      self.assertEqual((1, size * 2), out[0].shape)

  def testReduceTreeLSTM(self):
    with tf.device(self._test_device):
      size = 10
      tracker_size = 8
      reducer = spinn.Reducer(size, tracker_size=tracker_size)

      lstm_in = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                          [0, -1, -2, -3, -4, -5, -6, -7, -8, -9]],
                         dtype=np.float32)
      c1 = np.array([[0, 1], [2, 3]], dtype=np.float32)
      c2 = np.array([[0, -1], [-2, -3]], dtype=np.float32)

      h, c = reducer._tree_lstm(c1, c2, lstm_in)
      self.assertEqual(tf.float32, h.dtype)
      self.assertEqual(tf.float32, c.dtype)
      self.assertEqual((2, 2), h.shape)
      self.assertEqual((2, 2), c.shape)

  def testTracker(self):
    with tf.device(self._test_device):
      batch_size = 2
      size = 10
      tracker_size = 8
      buffer_length = 18
      stack_size = 3

      tracker = spinn.Tracker(tracker_size, False)
      tracker.reset_state()

      # Create dummy inputs for testing.
      bufs = []
      buf = []
      for _ in range(buffer_length):
        buf.append(tf.random_normal((batch_size, size * 2)))
      bufs.append(buf)
      self.assertEqual(1, len(bufs))
      self.assertEqual(buffer_length, len(bufs[0]))
      self.assertEqual((batch_size, size * 2), bufs[0][0].shape)

      stacks = []
      stack = []
      for _ in range(stack_size):
        stack.append(tf.random_normal((batch_size, size * 2)))
      stacks.append(stack)
      self.assertEqual(1, len(stacks))
      self.assertEqual(stack_size, len(stacks[0]))
      self.assertEqual((batch_size, size * 2), stacks[0][0].shape)

      for _ in range(2):
        out1, out2 = tracker(bufs, stacks)
        self.assertIsNone(out2)
        self.assertEqual(batch_size, len(out1))
        self.assertEqual(tf.float32, out1[0].dtype)
        self.assertEqual((1, tracker_size * 2), out1[0].shape)

        self.assertEqual(tf.float32, tracker.state.c.dtype)
        self.assertEqual((batch_size, tracker_size), tracker.state.c.shape)
        self.assertEqual(tf.float32, tracker.state.h.dtype)
        self.assertEqual((batch_size, tracker_size), tracker.state.h.shape)

  def testSPINN(self):
    with tf.device(self._test_device):
      embedding_dims = 10
      d_tracker = 8
      sequence_length = 15
      num_transitions = 27

      config_tuple = collections.namedtuple(
          "Config", ["d_hidden", "d_proj", "d_tracker", "predict"])
      config = config_tuple(
          embedding_dims, embedding_dims * 2, d_tracker, False)
      s = spinn.SPINN(config)

      # Create some fake data.
      buffers = tf.random_normal((sequence_length, 1, config.d_proj))
      transitions = tf.constant(
          [[3], [3], [2], [3], [3], [3], [2], [2], [2], [3], [3], [3],
           [2], [3], [3], [2], [2], [3], [3], [3], [2], [2], [2], [2],
           [3], [2], [2]], dtype=tf.int64)
      self.assertEqual(tf.int64, transitions.dtype)
      self.assertEqual((num_transitions, 1), transitions.shape)

      out = s(buffers, transitions, training=True)
      self.assertEqual(tf.float32, out.dtype)
      self.assertEqual((1, embedding_dims), out.shape)

  def testSNLIClassifierAndTrainer(self):
    with tf.device(self._test_device):
      vocab_size = 40
      batch_size = 2
      d_embed = 10
      sequence_length = 15
      d_out = 4

      config = _test_spinn_config(d_embed, d_out)

      # Create fake embedding matrix.
      embed = tf.random_normal((vocab_size, d_embed))

      model = spinn.SNLIClassifier(config, embed)
      trainer = spinn.SNLIClassifierTrainer(model, config.lr)

      (labels, prem, prem_trans, hypo,
       hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length,
                                                         batch_size,
                                                         vocab_size)

      # Invoke model under non-training mode.
      logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
      self.assertEqual(tf.float32, logits.dtype)
      self.assertEqual((batch_size, d_out), logits.shape)

      # Invoke model under training mode.
      logits = model(prem, prem_trans, hypo, hypo_trans, training=True)
      self.assertEqual(tf.float32, logits.dtype)
      self.assertEqual((batch_size, d_out), logits.shape)

      # Calculate loss.
      loss1 = trainer.loss(labels, logits)
      self.assertEqual(tf.float32, loss1.dtype)
      self.assertEqual((), loss1.shape)

      loss2, logits = trainer.train_batch(
          labels, prem, prem_trans, hypo, hypo_trans)
      self.assertEqual(tf.float32, loss2.dtype)
      self.assertEqual((), loss2.shape)
      self.assertEqual(tf.float32, logits.dtype)
      self.assertEqual((batch_size, d_out), logits.shape)
      # Training on the batch should have led to a change in the loss value.
      self.assertNotEqual(loss1.numpy(), loss2.numpy())

  def _create_test_data(self, snli_1_0_dir):
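    """Write a fake SNLI train file and fake GloVe vectors to the temp dir.

    Each fake GloVe vector is constant across dimensions (0.1 * the word's
    index), which keeps vocabulary and embedding loading deterministic.

    Args:
      snli_1_0_dir: Directory in which to create the fake SNLI train file.

    Returns:
      Path to the fake SNLI train file.
    """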
    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
    os.makedirs(snli_1_0_dir)

    # Four examples in total.
    with open(fake_train_file, "wt") as f:
      f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
              "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
              "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
      f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")

    glove_dir = os.path.join(self._temp_data_dir, "glove")
    os.makedirs(glove_dir)
    glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")

    # Write a constant vector (all dimensions equal to 0.1 * word index) for
    # each word.
    words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
    with open(glove_file, "wt") as f:
      for i, word in enumerate(words):
        f.write("%s " % word)
        for j in range(data.WORD_VECTOR_LEN):
          f.write("%.5f" % (i * 0.1))
          if j < data.WORD_VECTOR_LEN - 1:
            f.write(" ")
          else:
            f.write("\n")

    return fake_train_file

  def testInferSpinnWorks(self):
    """Test inference with the SPINN model."""
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    self._create_test_data(snli_1_0_dir)

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=os.path.join(self._temp_data_dir, "logdir"),
        inference_sentences=("( foo ( bar . ) )", "( bar ( foo . ) )"))
    logits = spinn.train_or_infer_spinn(
        embed, word2index, None, None, None, config)
    self.assertEqual(tf.float32, logits.dtype)
    self.assertEqual((3,), logits.shape)

  def testInferSpinnThrowsErrorIfOnlyOneSentenceIsSpecified(self):
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    self._create_test_data(snli_1_0_dir)

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=os.path.join(self._temp_data_dir, "logdir"),
        inference_sentences=("( foo ( bar . ) )", None))
    with self.assertRaises(ValueError):
      spinn.train_or_infer_spinn(embed, word2index, None, None, None, config)

  def testTrainSpinn(self):
    """Test with fake toy SNLI data and GloVe vectors."""

    # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = self._create_test_data(snli_1_0_dir)

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    train_data = data.SnliData(fake_train_file, word2index)
    dev_data = data.SnliData(fake_train_file, word2index)
    test_data = data.SnliData(fake_train_file, word2index)

    # 2. Create a fake config.
    config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=os.path.join(self._temp_data_dir, "logdir"))

    # 3. Test training of a SPINN model.
    trainer = spinn.train_or_infer_spinn(
        embed, word2index, train_data, dev_data, test_data, config)

    # 4. Load train loss values from the summary files and verify that one
    #    loss value was logged per training epoch.
    summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0]
    events = summary_test_util.events_from_file(summary_file)
    train_losses = [event.summary.value[0].simple_value for event in events
                    if event.summary.value
                    and event.summary.value[0].tag == "train/loss"]
    self.assertEqual(config.epochs, len(train_losses))

    # 5. Verify that checkpoints exist and contain all the expected variables.
    self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
    object_graph = trackable_utils.object_metadata(
        checkpoint_management.latest_checkpoint(config.logdir))
    ckpt_variable_names = set()
    for node in object_graph.nodes:
      for attribute in node.attributes:
        ckpt_variable_names.add(attribute.full_name)
    self.assertIn("global_step", ckpt_variable_names)
    for v in trainer.variables:
      variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name
      self.assertIn(variable_name, ckpt_variable_names)


class EagerSpinnSNLIClassifierBenchmark(test.Benchmark):
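  """Benchmarks train-step throughput of the eager SPINN SNLI classifier."""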

  def benchmarkEagerSpinnSNLIClassifier(self):
    test_device = "gpu:0" if tfe.num_gpus() else "cpu:0"
    with tf.device(test_device):
      burn_in_iterations = 2
      benchmark_iterations = 10

      vocab_size = 1000
      batch_size = 128
      sequence_length = 15
      d_embed = 200
      d_out = 4

      embed = tf.random_normal((vocab_size, d_embed))

      config = _test_spinn_config(d_embed, d_out)
      model = spinn.SNLIClassifier(config, embed)
      trainer = spinn.SNLIClassifierTrainer(model, config.lr)

      (labels, prem, prem_trans, hypo,
       hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length,
                                                         batch_size,
                                                         vocab_size)

      # Burn in: run a few untimed steps so that one-time costs such as
      # variable creation do not skew the timing below.
      for _ in range(burn_in_iterations):
        trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans)

      gc.collect()
      start_time = time.time()
      for _ in xrange(benchmark_iterations):
        trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans)
      wall_time = time.time() - start_time
      # Named "examples_per_sec" to conform with other benchmarks, though the
      # value reported here is iterations per second.
      extras = {"examples_per_sec": benchmark_iterations / wall_time}
      self.report_benchmark(
          name="Eager_SPINN_SNLIClassifier_Benchmark",
          iters=benchmark_iterations,
          wall_time=wall_time,
          extras=extras)


if __name__ == "__main__":
  test.main()