# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for SNLI data and GloVe word vectors for the SPINN model.

See more details about the SNLI data set at:
  https://nlp.stanford.edu/projects/snli/

See more details about the GloVe pretrained word embeddings at:
  https://nlp.stanford.edu/projects/glove/
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import math
import os
import random

import numpy as np

POSSIBLE_LABELS = ("entailment", "contradiction", "neutral")

UNK_CODE = 0   # Code for unknown word tokens.
PAD_CODE = 1   # Code for padding tokens.

SHIFT_CODE = 3
REDUCE_CODE = 2

WORD_VECTOR_LEN = 300  # Embedding dimensions.

LEFT_PAREN = "("
RIGHT_PAREN = ")"
PARENTHESES = (LEFT_PAREN, RIGHT_PAREN)


def get_non_parenthesis_words(items):
  """Get the non-parenthesis items from a SNLI parsed sentence.

  Args:
    items: Data items from a parsed SNLI sentence, with parentheses. E.g.,
      ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ...

  Returns:
    A list of non-parentheses word items, all converted to lower case. E.g.,
      ["man", "wearing", "pass", ...
  """
  return [x.lower() for x in items if x not in PARENTHESES and x]
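
# Example (illustrative; the parse tokens below are made up):
#
#   get_non_parenthesis_words(["(", "(", "The", "dog", ")", "barks", ")"])
#   # -> ["the", "dog", "barks"]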


def get_shift_reduce(items):
  """Obtain shift-reduce vector from a list of items from the SNLI data.

  Args:
    items: Data items as a list of str, e.g.,
       ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ...

  Returns:
    A list of shift-reduce transitions, encoded as `SHIFT_CODE` for shift and
      `REDUCE_CODE` for reduce. See code above for the values of `SHIFT_CODE`
      and `REDUCE_CODE`.
  """
  trans = []
  for item in items:
    if item == LEFT_PAREN:
      continue
    elif item == RIGHT_PAREN:
      trans.append(REDUCE_CODE)
    else:
      trans.append(SHIFT_CODE)
  return trans
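
# Example (illustrative; the parse tokens below are made up):
#
#   get_shift_reduce(["(", "(", "The", "dog", ")", "barks", ")"])
#   # -> [3, 3, 2, 3, 2]  (SHIFT_CODE=3, REDUCE_CODE=2)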


def pad_and_reverse_word_ids(sentences):
  """Pad a list of sentences to the common maximum length + 1.

  Args:
    sentences: A list of sentences, where each sentence is a list of integers.
      Each integer is a word ID; each inner list corresponds to one sentence.

  Returns:
    A numpy.ndarray of shape (num_sentences, max_length + 1), wherein
      max_length is the maximum sentence length (in # of words). Each sentence
      is reversed and then padded with an extra one at the head, as required by
      the model.
  """
  max_len = max(len(sent) for sent in sentences)
  for sent in sentences:
    if len(sent) < max_len:
      sent.extend([PAD_CODE] * (max_len - len(sent)))
  # Reverse in time order and pad an extra one.
  sentences = np.fliplr(np.array(sentences, dtype=np.int64))
  sentences = np.concatenate(
      [np.ones([sentences.shape[0], 1], dtype=np.int64), sentences], axis=1)
  return sentences
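
# Example (illustrative; the word IDs below are made up):
#
#   pad_and_reverse_word_ids([[3, 4, 5], [6, 7]])
#   # -> array([[1, 5, 4, 3],
#   #           [1, 1, 7, 6]])
#   # Each row is reversed, short rows are padded with PAD_CODE, and an extra
#   # one is prepended to every row.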


def pad_transitions(sentences_transitions):
  """Pad a list of shift-reduce transitions to the maximum length."""
  max_len = max(len(transitions) for transitions in sentences_transitions)
  for transitions in sentences_transitions:
    if len(transitions) < max_len:
      transitions.extend([PAD_CODE] * (max_len - len(transitions)))
  return np.array(sentences_transitions, dtype=np.int64)
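
# Example (illustrative):
#
#   pad_transitions([[3, 3, 2], [3, 3, 2, 3, 2]])
#   # -> array([[3, 3, 2, 1, 1],
#   #           [3, 3, 2, 3, 2]])
#   # Shorter transition sequences are padded with PAD_CODE.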


def load_vocabulary(data_root):
  """Load vocabulary from SNLI data files.

  Args:
    data_root: Root directory of the data. It is assumed that the SNLI data
      files have been downloaded and extracted to the "snli/snli_1.0"
      subdirectory of it.

  Returns:
    Vocabulary as a set of strings.

  Raises:
    ValueError: If SNLI data files cannot be found.
  """
  snli_path = os.path.join(data_root, "snli")
  snli_glob_pattern = os.path.join(snli_path, "snli_1.0/snli_1.0_*.txt")
  file_names = glob.glob(snli_glob_pattern)
  if not file_names:
    raise ValueError(
        "Cannot find SNLI data files at %s. "
        "Please download and extract SNLI data first." % snli_glob_pattern)

  print("Loading vocabulary...")
  vocab = set()
  for file_name in file_names:
    # glob.glob() already returns the directory prefix as part of file_name.
    with open(file_name, "rt") as f:
      for i, line in enumerate(f):
        if i == 0:
          continue
        items = line.split("\t")
        premise_words = get_non_parenthesis_words(items[1].split(" "))
        hypothesis_words = get_non_parenthesis_words(items[2].split(" "))
        vocab.update(premise_words)
        vocab.update(hypothesis_words)
  return vocab


def load_word_vectors(data_root, vocab):
  """Load GloVe word vectors for words present in the vocabulary.

  Args:
    data_root: Data root directory. It is assumed that the GloVe file has been
      downloaded and extracted to the "glove/" subdirectory of it.
    vocab: A `set` of words, representing the vocabulary.

  Returns:
    1. word2index: A dict from lower-case word to row index in the embedding
       matrix, i.e., `embed` below.
    2. embed: The embedding matrix as a float32 numpy array. Its shape is
       [vocabulary_size, WORD_VECTOR_LEN], where vocabulary_size is the number
       of vocabulary words found in the GloVe file, plus 2 for the special
       <unk> and <pad> entries. WORD_VECTOR_LEN is the embedding dimension
       (300).

  Raises:
    ValueError: If GloVe embedding file cannot be found.
  """
  glove_path = os.path.join(data_root, "glove/glove.42B.300d.txt")
  if not os.path.isfile(glove_path):
    raise ValueError(
        "Cannot find GloVe embedding file at %s. "
        "Please download and extract GloVe embeddings first." % glove_path)

  print("Loading word vectors...")

  word2index = dict()
  embed = []

  embed.append([0] * WORD_VECTOR_LEN)  # <unk>
  embed.append([0] * WORD_VECTOR_LEN)  # <pad>
  word2index["<unk>"] = UNK_CODE
  word2index["<pad>"] = PAD_CODE

  with open(glove_path, "rt") as f:
    for line in f:
      items = line.split(" ")
      word = items[0]
      if word in vocab and word not in word2index:
        word2index[word] = len(embed)
        vector = np.array([float(item) for item in items[1:]])
        assert (WORD_VECTOR_LEN,) == vector.shape
        embed.append(vector)
  embed = np.array(embed, dtype=np.float32)
  return word2index, embed
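
# Illustrative usage sketch (hypothetical paths; assumes the SNLI and GloVe
# files have already been downloaded and extracted under the data root):
#
#   vocab = load_vocabulary("/tmp/spinn-data")
#   word2index, embed = load_word_vectors("/tmp/spinn-data", vocab)
#   vector = embed[word2index.get("dog", UNK_CODE)]  # A (300,) float32 vector.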


def calculate_bins(length2count, min_bin_size):
  """Calculate bin boundaries given a histogram of lengths and minimum bin size.

  Args:
    length2count: A `dict` mapping length to sentence count.
    min_bin_size: Minimum bin size in terms of total number of sentence pairs
      in the bin.

  Returns:
    A `list` representing the right bin boundaries, starting from the inclusive
    right boundary of the first bin. For example, if the output is
      [10, 20, 35],
    it means there are three bins: [1, 10], [11, 20] and [21, 35].
  """
  bounds = []
  lengths = sorted(length2count.keys())
  cum_count = 0
  for length in lengths:
    cum_count += length2count[length]
    if cum_count >= min_bin_size:
      bounds.append(length)
      cum_count = 0
  if not bounds or bounds[-1] != lengths[-1]:
    bounds.append(lengths[-1])
  return bounds
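
# Example (illustrative counts):
#
#   calculate_bins({3: 5, 4: 8, 5: 2, 6: 9}, min_bin_size=10)
#   # -> [4, 6], i.e., two bins covering lengths [1, 4] and [5, 6].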


def encode_sentence(sentence, word2index):
  """Encode a single sentence as word indices and shift-reduce code.

  Args:
    sentence: The sentence with added binary parse information, represented as
      a string, with all the word items and parentheses separated by spaces.
      E.g., '( ( The dog ) ( ( is ( playing toys ) ) . ) )'.
    word2index: A `dict` mapping words to their word indices.

  Returns:
     1. Word indices as a numpy array, with shape `(sequence_len, 1)`.
     2. Shift-reduce sequence as a numpy array, with shape
       `(sequence_len * 2 - 3, 1)`.
  """
  items = [w for w in sentence.split(" ") if w]
  words = get_non_parenthesis_words(items)
  shift_reduce = get_shift_reduce(items)
  word_indices = pad_and_reverse_word_ids(
      [[word2index.get(word, UNK_CODE) for word in words]]).T
  return (word_indices,
          np.expand_dims(np.array(shift_reduce, dtype=np.int64), -1))
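
# Example (illustrative; the word IDs in `word2index` are made up):
#
#   word_ids, transitions = encode_sentence(
#       "( ( The dog ) barks )", {"the": 5, "dog": 7, "barks": 9})
#   # word_ids.shape == (4, 1): the 3 word IDs reversed, plus the extra head 1.
#   # transitions.shape == (5, 1): 2 * 3 - 1 shift-reduce steps.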


class SnliData(object):
  """A split of SNLI data."""

  def __init__(self, data_file, word2index, sentence_len_limit=-1):
    """SnliData constructor.

    Args:
      data_file: Full path to the data file, e.g.,
        "/tmp/spinn-data/snli/snli_1.0/snli_1.0_train.txt"
      word2index: A dict from lower-case word to row index in the embedding
        matrix (see `load_word_vectors()` for details).
      sentence_len_limit: Maximum allowed sentence length (# of words).
        A value of <= 0 means unlimited. Sentences longer than this limit
        are currently discarded, not truncated.
    """

    self._labels = []
    self._premises = []
    self._premise_transitions = []
    self._hypotheses = []
    self._hypothesis_transitions = []

    with open(data_file, "rt") as f:
      for i, line in enumerate(f):
        if i == 0:
          # Skip header line.
          continue
        items = line.split("\t")
        if items[0] not in POSSIBLE_LABELS:
          continue

        premise_items = items[1].split(" ")
        hypothesis_items = items[2].split(" ")
        premise_words = get_non_parenthesis_words(premise_items)
        hypothesis_words = get_non_parenthesis_words(hypothesis_items)

        if (sentence_len_limit > 0 and
            (len(premise_words) > sentence_len_limit or
             len(hypothesis_words) > sentence_len_limit)):
          # TODO(cais): Maybe truncate; do not discard.
          continue

        premise_ids = [
            word2index.get(word, UNK_CODE) for word in premise_words]
        hypothesis_ids = [
            word2index.get(word, UNK_CODE) for word in hypothesis_words]

        self._premises.append(premise_ids)
        self._hypotheses.append(hypothesis_ids)
        self._premise_transitions.append(get_shift_reduce(premise_items))
        self._hypothesis_transitions.append(get_shift_reduce(hypothesis_items))
        assert (len(self._premise_transitions[-1]) ==
                2 * len(premise_words) - 1)
        assert (len(self._hypothesis_transitions[-1]) ==
                2 * len(hypothesis_words) - 1)

        self._labels.append(POSSIBLE_LABELS.index(items[0]) + 1)

    assert len(self._labels) == len(self._premises)
    assert len(self._labels) == len(self._hypotheses)
    assert len(self._labels) == len(self._premise_transitions)
    assert len(self._labels) == len(self._hypothesis_transitions)

  def num_batches(self, batch_size):
    """Calculate number of batches given batch size."""
    return int(math.ceil(len(self._labels) / batch_size))

  def get_generator(self, batch_size):
    """Obtain a generator for batched data.

    All examples of this SnliData object are randomly shuffled, sorted
    according to the maximum sentence length of the premise and hypothesis
    sentences in the pair, and batched.

    Args:
      batch_size: Desired batch size.

    Returns:
      A generator function for data batches. Each generator it creates yields
        5-tuples of the following:
        label: An array of the shape (batch_size,).
        premise: An array of the shape (max_premise_len, batch_size), wherein
          max_premise_len is the maximum length of the (padded) premise
          sentence in the batch.
        premise_transitions: An array of the shape (2 * max_premise_len - 3,
          batch_size).
        hypothesis: Same as `premise`, but for hypothesis sentences.
        hypothesis_transitions: Same as `premise_transitions`, but for
          hypothesis sentences.
      All the elements of the 5-tuple have dtype `int64`.
    """
    # Randomly shuffle examples.
    zipped = list(zip(
        self._labels, self._premises, self._premise_transitions,
        self._hypotheses, self._hypothesis_transitions))
    random.shuffle(zipped)
    # Then sort the examples by maximum of the premise and hypothesis sentence
    # lengths in the pair. During training, the batches are expected to be
    # shuffled. So it is okay to leave them sorted by max length here.
    (labels, premises, premise_transitions, hypotheses,
     hypothesis_transitions) = zip(
         *sorted(zipped, key=lambda x: max(len(x[1]), len(x[3]))))

    def _generator():
      begin = 0
      while begin < len(labels):
        # The sorting above and the batching here make sure that sentences of
        # similar max lengths are batched together, minimizing the inefficiency
        # due to uneven max lengths. The sentences are batched differently in
        # each call to get_generator() due to the shuffling before sorting
        # above. The pad_and_reverse_word_ids() and pad_transitions() functions
        # take care of any remaining unevenness of the max sentence lengths.
        end = min(begin + batch_size, len(labels))
        # Transpose, because the SPINN model requires time-major, instead of
        # batch-major.
        yield (labels[begin:end],
               pad_and_reverse_word_ids(premises[begin:end]).T,
               pad_transitions(premise_transitions[begin:end]).T,
               pad_and_reverse_word_ids(hypotheses[begin:end]).T,
               pad_transitions(hypothesis_transitions[begin:end]).T)
        begin = end
    return _generator
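
# Illustrative usage sketch (hypothetical paths; assumes the SNLI and GloVe
# data have been downloaded and extracted as described in the docstrings
# above):
#
#   vocab = load_vocabulary("/tmp/spinn-data")
#   word2index, _ = load_word_vectors("/tmp/spinn-data", vocab)
#   train_data = SnliData(
#       "/tmp/spinn-data/snli/snli_1.0/snli_1.0_train.txt", word2index)
#   batches = train_data.get_generator(batch_size=32)()
#   label, premise, premise_trans, hypothesis, hypothesis_trans = next(batches)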