# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for loading SNLI data and GloVe word vectors for the SPINN model.

See more details about the SNLI data set at:
  https://nlp.stanford.edu/projects/snli/

See more details about the GloVe pretrained word embeddings at:
  https://nlp.stanford.edu/projects/glove/
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import math
import os
import random

import numpy as np

POSSIBLE_LABELS = ("entailment", "contradiction", "neutral")

UNK_CODE = 0  # Code for unknown word tokens.
PAD_CODE = 1  # Code for padding tokens.

SHIFT_CODE = 3
REDUCE_CODE = 2

WORD_VECTOR_LEN = 300  # Embedding dimensions.

LEFT_PAREN = "("
RIGHT_PAREN = ")"
PARENTHESES = (LEFT_PAREN, RIGHT_PAREN)


def get_non_parenthesis_words(items):
  """Get the non-parenthesis items from a parsed SNLI sentence.

  Args:
    items: Data items from a parsed SNLI sentence, with parentheses. E.g.,
      ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ...

  Returns:
    A list of non-parenthesis word items, all converted to lower case. E.g.,
      ["man", "wearing", "pass", ...
  """
  return [x.lower() for x in items if x not in PARENTHESES and x]


def get_shift_reduce(items):
  """Obtain the shift-reduce vector from a list of items from the SNLI data.

  Args:
    items: Data items as a list of str, e.g.,
      ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ...

  Returns:
    A list of shift-reduce transitions, encoded as `SHIFT_CODE` for shift and
    `REDUCE_CODE` for reduce. See the constants above for the values of
    `SHIFT_CODE` and `REDUCE_CODE`.
  """
  trans = []
  for item in items:
    if item == LEFT_PAREN:
      continue
    elif item == RIGHT_PAREN:
      trans.append(REDUCE_CODE)
    else:
      trans.append(SHIFT_CODE)
  return trans


def pad_and_reverse_word_ids(sentences):
  """Pad a list of sentences to the common maximum length + 1.

  Args:
    sentences: A list of sentences as a list of lists of integers. Each
      integer is a word ID. Each inner list corresponds to one sentence.

  Returns:
    A numpy.ndarray of shape (num_sentences, max_length + 1), where max_length
    is the maximum sentence length (in # of words). Each sentence is reversed
    and then padded with an extra one at the head, as required by the model.
  """
  max_len = max(len(sent) for sent in sentences)
  for sent in sentences:
    if len(sent) < max_len:
      sent.extend([PAD_CODE] * (max_len - len(sent)))
  # Reverse in time order and pad an extra one at the head.
  sentences = np.fliplr(np.array(sentences, dtype=np.int64))
  sentences = np.concatenate(
      [np.ones([sentences.shape[0], 1], dtype=np.int64), sentences], axis=1)
  return sentences
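

# Illustrative sketch only (hypothetical word IDs, not from any real
# vocabulary): `pad_and_reverse_word_ids` pads all sentences to the common
# maximum length with PAD_CODE, reverses them in time, and prepends a column
# of ones:
#   pad_and_reverse_word_ids([[4, 5, 6], [7, 8]])
#   => array([[1, 6, 5, 4],
#             [1, 1, 8, 7]])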


def pad_transitions(sentences_transitions):
  """Pad a list of shift-reduce transitions to the maximum length."""
  max_len = max(len(transitions) for transitions in sentences_transitions)
  for transitions in sentences_transitions:
    if len(transitions) < max_len:
      transitions.extend([PAD_CODE] * (max_len - len(transitions)))
  return np.array(sentences_transitions, dtype=np.int64)


def load_vocabulary(data_root):
  """Load the vocabulary from SNLI data files.

  Args:
    data_root: Root directory of the data. It is assumed that the SNLI data
      files have been downloaded and extracted to the "snli/snli_1.0"
      subdirectory of it.

  Returns:
    Vocabulary as a set of strings.

  Raises:
    ValueError: If SNLI data files cannot be found.
  """
  snli_path = os.path.join(data_root, "snli")
  snli_glob_pattern = os.path.join(snli_path, "snli_1.0/snli_1.0_*.txt")
  file_names = glob.glob(snli_glob_pattern)
  if not file_names:
    raise ValueError(
        "Cannot find SNLI data files at %s. "
        "Please download and extract SNLI data first." % snli_glob_pattern)

  print("Loading vocabulary...")
  vocab = set()
  for file_name in file_names:
    # glob.glob() already returns paths rooted at snli_path, so the files can
    # be opened directly.
    with open(file_name, "rt") as f:
      for i, line in enumerate(f):
        if i == 0:
          # Skip the header line.
          continue
        items = line.split("\t")
        premise_words = get_non_parenthesis_words(items[1].split(" "))
        hypothesis_words = get_non_parenthesis_words(items[2].split(" "))
        vocab.update(premise_words)
        vocab.update(hypothesis_words)
  return vocab


def load_word_vectors(data_root, vocab):
  """Load GloVe word vectors for words present in the vocabulary.

  Args:
    data_root: Data root directory. It is assumed that the GloVe file
      has been downloaded and extracted at the "glove/" subdirectory of it.
    vocab: A `set` of words, representing the vocabulary.

  Returns:
    1. word2index: A dict from lower-case word to row index in the embedding
       matrix, i.e., `embed` below.
    2. embed: The embedding matrix as a float32 numpy array. Its shape is
       [vocabulary_size, WORD_VECTOR_LEN], where vocabulary_size is the number
       of vocabulary words found in the GloVe file, plus 2 for the special
       <unk> and <pad> tokens. WORD_VECTOR_LEN is the embedding dimension
       (300).

  Raises:
    ValueError: If the GloVe embedding file cannot be found.
  """
  glove_path = os.path.join(data_root, "glove/glove.42B.300d.txt")
  if not os.path.isfile(glove_path):
    raise ValueError(
        "Cannot find GloVe embedding file at %s. "
        "Please download and extract GloVe embeddings first." % glove_path)

  print("Loading word vectors...")

  word2index = dict()
  embed = []

  embed.append([0] * WORD_VECTOR_LEN)  # <unk>
  embed.append([0] * WORD_VECTOR_LEN)  # <pad>
  word2index["<unk>"] = UNK_CODE
  word2index["<pad>"] = PAD_CODE

  with open(glove_path, "rt") as f:
    for line in f:
      items = line.split(" ")
      word = items[0]
      if word in vocab and word not in word2index:
        word2index[word] = len(embed)
        vector = np.array([float(item) for item in items[1:]])
        assert (WORD_VECTOR_LEN,) == vector.shape
        embed.append(vector)
  embed = np.array(embed, dtype=np.float32)
  return word2index, embed
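

# A typical loading sequence (sketch only; the "/tmp/spinn-data" path below is
# hypothetical):
#   vocab = load_vocabulary("/tmp/spinn-data")
#   word2index, embed = load_word_vectors("/tmp/spinn-data", vocab)
#   # embed.shape == (len(word2index), WORD_VECTOR_LEN)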


def calculate_bins(length2count, min_bin_size):
  """Calculate bin boundaries given a histogram of lengths and a minimum size.

  Args:
    length2count: A `dict` mapping length to sentence count.
    min_bin_size: Minimum bin size in terms of the total number of sentence
      pairs in the bin.

  Returns:
    A `list` of the right bin boundaries, starting from the inclusive right
    boundary of the first bin. For example, if the output is
      [10, 20, 35],
    there are three bins: [1, 10], [11, 20] and [21, 35].
  """
  bounds = []
  lengths = sorted(length2count.keys())
  cum_count = 0
  for length in lengths:
    cum_count += length2count[length]
    if cum_count >= min_bin_size:
      bounds.append(length)
      cum_count = 0
  # Guard against the case in which no bin has reached min_bin_size; the last
  # bin always ends at the maximum length.
  if not bounds or bounds[-1] != lengths[-1]:
    bounds.append(lengths[-1])
  return bounds


def encode_sentence(sentence, word2index):
  """Encode a single sentence as word indices and a shift-reduce code.

  Args:
    sentence: The sentence with added binary parse information, represented as
      a string, with all the word items and parentheses separated by spaces.
      E.g., '( ( The dog ) ( ( is ( playing toys ) ) . ) )'.
    word2index: A `dict` mapping words to their word indices.

  Returns:
    1. Word indices as a numpy array, with shape `(sequence_len, 1)`.
    2. Shift-reduce sequence as a numpy array, with shape
       `(sequence_len * 2 - 3, 1)`.
  """
  items = [w for w in sentence.split(" ") if w]
  words = get_non_parenthesis_words(items)
  shift_reduce = get_shift_reduce(items)
  word_indices = pad_and_reverse_word_ids(
      [[word2index.get(word, UNK_CODE) for word in words]]).T
  return (word_indices,
          np.expand_dims(np.array(shift_reduce, dtype=np.int64), -1))
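

# Illustrative sketch only, assuming `word2index` has already been built with
# load_word_vectors() (unknown words map to UNK_CODE). For the 6-word parsed
# sentence below, the returned word-index array has shape (7, 1) (6 words plus
# one leading pad) and the shift-reduce array has shape (11, 1), i.e.
# 2 * 6 - 1 transitions:
#   word_ids, transitions = encode_sentence(
#       "( ( The dog ) ( ( is ( playing toys ) ) . ) )", word2index)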


class SnliData(object):
  """A split of SNLI data."""

  def __init__(self, data_file, word2index, sentence_len_limit=-1):
    """SnliData constructor.

    Args:
      data_file: Full path to the data file, e.g.,
        "/tmp/spinn-data/snli/snli_1.0/snli_1.0_train.txt".
      word2index: A dict from lower-case word to row index in the embedding
        matrix (see `load_word_vectors()` for details).
      sentence_len_limit: Maximum allowed sentence length (# of words).
        A value of <= 0 means unlimited. Sentences longer than this limit
        are currently discarded, not truncated.
    """

    self._labels = []
    self._premises = []
    self._premise_transitions = []
    self._hypotheses = []
    self._hypothesis_transitions = []

    with open(data_file, "rt") as f:
      for i, line in enumerate(f):
        if i == 0:
          # Skip the header line.
          continue
        items = line.split("\t")
        if items[0] not in POSSIBLE_LABELS:
          continue

        premise_items = items[1].split(" ")
        hypothesis_items = items[2].split(" ")
        premise_words = get_non_parenthesis_words(premise_items)
        hypothesis_words = get_non_parenthesis_words(hypothesis_items)

        if (sentence_len_limit > 0 and
            (len(premise_words) > sentence_len_limit or
             len(hypothesis_words) > sentence_len_limit)):
          # TODO(cais): Maybe truncate; do not discard.
          continue

        premise_ids = [
            word2index.get(word, UNK_CODE) for word in premise_words]
        hypothesis_ids = [
            word2index.get(word, UNK_CODE) for word in hypothesis_words]

        self._premises.append(premise_ids)
        self._hypotheses.append(hypothesis_ids)
        self._premise_transitions.append(get_shift_reduce(premise_items))
        self._hypothesis_transitions.append(get_shift_reduce(hypothesis_items))
        assert (len(self._premise_transitions[-1]) ==
                2 * len(premise_words) - 1)
        assert (len(self._hypothesis_transitions[-1]) ==
                2 * len(hypothesis_words) - 1)

        self._labels.append(POSSIBLE_LABELS.index(items[0]) + 1)

    assert len(self._labels) == len(self._premises)
    assert len(self._labels) == len(self._hypotheses)
    assert len(self._labels) == len(self._premise_transitions)
    assert len(self._labels) == len(self._hypothesis_transitions)

  def num_batches(self, batch_size):
    """Calculate the number of batches given a batch size."""
    return int(math.ceil(len(self._labels) / batch_size))

  def get_generator(self, batch_size):
    """Obtain a generator for batched data.

    All examples of this SnliData object are randomly shuffled, sorted
    according to the maximum sentence length of the premise and hypothesis
    sentences in the pair, and batched.

    Args:
      batch_size: Desired batch size.

    Returns:
      A generator function. Each call to it returns a generator that yields
      5-tuples of:
        label: An array of shape (batch_size,).
        premise: An array of shape (max_premise_len, batch_size), where
          max_premise_len is the maximum length of the (padded) premise
          sentences in the batch.
        premise_transitions: An array of shape (2 * max_premise_len - 3,
          batch_size).
        hypothesis: Same as `premise`, but for hypothesis sentences.
        hypothesis_transitions: Same as `premise_transitions`, but for
          hypothesis sentences.
      All the elements of the 5-tuple have dtype `int64`.
    """
    # Randomly shuffle the examples.
    zipped = list(zip(
        self._labels, self._premises, self._premise_transitions,
        self._hypotheses, self._hypothesis_transitions))
    random.shuffle(zipped)
    # Then sort the examples by the maximum of the premise and hypothesis
    # sentence lengths in the pair. During training, the batches are expected
    # to be shuffled, so it is okay to leave them sorted by max length here.
    (labels, premises, premise_transitions, hypotheses,
     hypothesis_transitions) = zip(
         *sorted(zipped, key=lambda x: max(len(x[1]), len(x[3]))))

    def _generator():
      begin = 0
      while begin < len(labels):
        # The sorting above and the batching here make sure that sentences of
        # similar max lengths are batched together, minimizing the
        # inefficiency due to uneven max lengths. The sentences are batched
        # differently in each call to get_generator() due to the shuffling
        # before sorting above.
        # The pad_and_reverse_word_ids() and pad_transitions() functions
        # take care of any remaining unevenness of the max sentence lengths.
        end = min(begin + batch_size, len(labels))
        # Transpose, because the SPINN model requires time-major, instead of
        # batch-major, data.
        yield (labels[begin:end],
               pad_and_reverse_word_ids(premises[begin:end]).T,
               pad_transitions(premise_transitions[begin:end]).T,
               pad_and_reverse_word_ids(hypotheses[begin:end]).T,
               pad_transitions(hypothesis_transitions[begin:end]).T)
        begin = end
    return _generator
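

# A minimal end-to-end usage sketch (all paths and the batch size below are
# hypothetical):
#   vocab = load_vocabulary("/tmp/spinn-data")
#   word2index, embed = load_word_vectors("/tmp/spinn-data", vocab)
#   train_data = SnliData(
#       "/tmp/spinn-data/snli/snli_1.0/snli_1.0_train.txt", word2index)
#   batches = train_data.get_generator(batch_size=128)()
#   label, premise, premise_trans, hypothesis, hypothesis_trans = next(batches)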