# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Skip-gram sampling ops from https://arxiv.org/abs/1301.3781."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import csv

from tensorflow.contrib import lookup
from tensorflow.contrib.text.python.ops import gen_skip_gram_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import resource_loader
from tensorflow.python.training import input as input_ops

_skip_gram_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_skip_gram_ops.so"))

ops.NotDifferentiable("SkipGramGenerateCandidates")


def skip_gram_sample(input_tensor,
                     min_skips=1,
                     max_skips=5,
                     start=0,
                     limit=-1,
                     emit_self_as_target=False,
                     vocab_freq_table=None,
                     vocab_min_count=None,
                     vocab_subsampling=None,
                     corpus_size=None,
                     batch_size=None,
                     batch_capacity=None,
                     seed=None,
                     name=None):
  """Generates skip-gram token and label paired Tensors from the input tensor.

  Generates skip-gram `("token", "label")` pairs using each element in the
  rank-1 `input_tensor` as a token. The window size used for each token will be
  randomly selected from the range specified by `[min_skips, max_skips]`,
  inclusive. See https://arxiv.org/abs/1301.3781 for more details about
  skip-gram.

  For example, given `input_tensor = ["the", "quick", "brown", "fox", "jumps"]`,
  `min_skips = 1`, `max_skips = 2`, `emit_self_as_target = False`, the output
  `(tokens, labels)` pairs for the token "quick" will be randomly selected from
  either `(tokens=["quick", "quick"], labels=["the", "brown"])` for 1 skip, or
  `(tokens=["quick", "quick", "quick"], labels=["the", "brown", "fox"])` for 2
  skips.

  If `emit_self_as_target = True`, each token will also be emitted as a label
  for itself. From the previous example, the output will be either
  `(tokens=["quick", "quick", "quick"], labels=["the", "quick", "brown"])` for 1
  skip, or `(tokens=["quick", "quick", "quick", "quick"], labels=["the",
  "quick", "brown", "fox"])` for 2 skips.

  The same process is repeated for each element of `input_tensor` and
  concatenated together into the two output rank-1 `Tensors` (one for all the
  tokens, another for all the labels).

  If `vocab_freq_table` is specified, tokens in `input_tensor` that are not
  present in the vocabulary are discarded. Tokens whose frequency counts are
  below `vocab_min_count` are also discarded.
  Tokens whose frequency proportions in the corpus exceed `vocab_subsampling`
  may be randomly down-sampled. See Eq. 5 in http://arxiv.org/abs/1310.4546
  for more details about subsampling.

  Due to the random window sizes used for each token, the lengths of the
  outputs are non-deterministic, unless `batch_size` is specified to batch the
  outputs to always return `Tensors` of length `batch_size`.

  Args:
    input_tensor: A rank-1 `Tensor` from which to generate skip-gram candidates.
    min_skips: `int` or scalar `Tensor` specifying the minimum window size to
      randomly use for each token. Must be >= 0 and <= `max_skips`. If
      `min_skips` and `max_skips` are both 0, the only label outputted will be
      the token itself when `emit_self_as_target = True`; otherwise, nothing is
      emitted.
    max_skips: `int` or scalar `Tensor` specifying the maximum window size to
      randomly use for each token. Must be >= 0.
    start: `int` or scalar `Tensor` specifying the position in
      `input_tensor` from which to start generating skip-gram candidates.
    limit: `int` or scalar `Tensor` specifying the maximum number of
      elements in `input_tensor` to use in generating skip-gram candidates. -1
      means to use the rest of the `Tensor` after `start`.
    emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
      each token as a label for itself.
    vocab_freq_table: (Optional) A lookup table (subclass of
      `lookup.InitializableLookupTableBase`) that maps tokens to their raw
      frequency counts. If specified, any token in `input_tensor` that is not
      found in `vocab_freq_table` will be filtered out before generating
      skip-gram candidates. While this will typically map to integer raw
      frequency counts, it could also map to float frequency proportions.
      `vocab_min_count` and `corpus_size` should be in the same units as this.
    vocab_min_count: (Optional) `int`, `float`, or scalar `Tensor` specifying
      minimum frequency threshold (from `vocab_freq_table`) for a token to be
      kept in `input_tensor`. If this is specified, `vocab_freq_table` must
      also be specified, and they should both be in the same units.
    vocab_subsampling: (Optional) `float` specifying frequency proportion
      threshold for tokens from `input_tensor`. Tokens that occur more
      frequently (based on the ratio of the token's `vocab_freq_table` value to
      the `corpus_size`) will be randomly down-sampled. Reasonable starting
      values may be around 1e-3 or 1e-5. If this is specified, both
      `vocab_freq_table` and `corpus_size` must also be specified. See Eq. 5
      in http://arxiv.org/abs/1310.4546 for more details.
    corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
      total number of tokens in the corpus (e.g., sum of all the frequency
      counts of `vocab_freq_table`). Used with `vocab_subsampling` for
      down-sampling frequently occurring tokens. If this is specified,
      `vocab_freq_table` and `vocab_subsampling` must also be specified.
    batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
    batch_capacity: (Optional) `int` specifying batch capacity for the queue
      used for batching returned `Tensors`. Only has an effect if
      `batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
    seed: (Optional) `int` used to create a random seed for window size and
      subsampling. See `set_random_seed` docs for behavior.
    name: (Optional) A `string` name or a name scope for the operations.

  Returns:
    A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
    rank-1 and has the same type as `input_tensor`. The `Tensors` will be of
    length `batch_size`; if `batch_size` is not specified, they will be of
    random length, though they will be in sync with each other as long as they
    are evaluated together.

  Raises:
    ValueError: If `vocab_freq_table` is not provided, but `vocab_min_count`,
      `vocab_subsampling`, or `corpus_size` is specified. If only one of
      `vocab_subsampling` and `corpus_size` is specified.
  """

  if vocab_freq_table is None and (vocab_min_count is not None or
                                   vocab_subsampling is not None or
                                   corpus_size is not None):
    raise ValueError(
        "vocab_freq_table is not provided, but vocab_min_count={}, "
        "vocab_subsampling={}, or corpus_size={} is not None. These settings "
        "are useless without a vocab_freq_table.".format(
            vocab_min_count, vocab_subsampling, corpus_size))

  if (vocab_subsampling is None) != (corpus_size is None):
    raise ValueError(
        "vocab_subsampling is {} while corpus_size is {} - both must be "
        "provided in order for subsampling to work.".format(
            vocab_subsampling, corpus_size))

  with ops.name_scope(
      name,
      "skip_gram_sample",
      values=[input_tensor, min_skips, max_skips, start, limit]):

    input_tensor = _filter_input(
        input_tensor=input_tensor,
        vocab_freq_table=vocab_freq_table,
        vocab_min_count=vocab_min_count,
        vocab_subsampling=vocab_subsampling,
        corpus_size=corpus_size,
        seed=seed)

    seed1, seed2 = random_seed.get_seed(seed)
    tokens, labels = gen_skip_gram_ops.skip_gram_generate_candidates(
        input_tensor=input_tensor,
        min_skips=min_skips,
        max_skips=max_skips,
        start=start,
        limit=limit,
        emit_self_as_target=emit_self_as_target,
        # Note that seed here should be seed1! This is due to
        # GuardedPhiloxRandom's hard-coded attributes of "seed" and "seed2".
        seed=seed1,
        seed2=seed2)

    # TODO(weiho): If the need arises, add support for sparse input_tensor that
    # figures out sentence boundaries, then calls
    # skip_gram_generate_candidates() on each sentence.

    # Batches the (tokens, labels) outputs so that they will be of
    # deterministic batch_size, to facilitate feeding them into the rest of
    # the network.
    if batch_size is not None and batch_size > 0:
      batch_capacity = (batch_capacity
                        if (batch_capacity is not None and batch_capacity > 0)
                        else 100 * batch_size)
      return input_ops.batch(
          [tokens, labels],
          batch_size,
          capacity=batch_capacity,
          enqueue_many=True)

    return tokens, labels
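

# Editorial sketch, not part of the original library: one plausible way to
# call skip_gram_sample() in a TF 1.x graph. The sentence contents, the
# parameter values, and the helper name `_skip_gram_sample_example` are
# illustrative assumptions only.
def _skip_gram_sample_example():
  """Builds skip-gram (token, label) pairs for a toy sentence."""
  sentence = ops.convert_to_tensor(["the", "quick", "brown", "fox", "jumps"])
  # Each token's window size is drawn uniformly from [min_skips, max_skips].
  tokens, labels = skip_gram_sample(
      sentence, min_skips=1, max_skips=2, emit_self_as_target=False, seed=42)
  # tokens and labels are rank-1 string Tensors of equal (random) length;
  # evaluating them together (e.g., in one session.run call) keeps them
  # aligned.
  return tokens, labels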


def skip_gram_sample_with_text_vocab(input_tensor,
                                     vocab_freq_file,
                                     vocab_token_index=0,
                                     vocab_token_dtype=dtypes.string,
                                     vocab_freq_index=1,
                                     vocab_freq_dtype=dtypes.float64,
                                     vocab_delimiter=",",
                                     vocab_min_count=0,
                                     vocab_subsampling=None,
                                     corpus_size=None,
                                     min_skips=1,
                                     max_skips=5,
                                     start=0,
                                     limit=-1,
                                     emit_self_as_target=False,
                                     batch_size=None,
                                     batch_capacity=None,
                                     seed=None,
                                     name=None):
  """Skip-gram sampling with a text vocabulary file.

  Wrapper around `skip_gram_sample()` for use with a text vocabulary file. The
  vocabulary file is expected to be a plain-text file, with lines of
  `vocab_delimiter`-separated columns. The `vocab_token_index` column should
  contain the vocabulary term, while the `vocab_freq_index` column should
  contain the number of times that term occurs in the corpus. For example, with
  a text vocabulary file of:

  ```
  bonjour,fr,42
  hello,en,777
  hola,es,99
  ```

  You should set `vocab_delimiter=","`, `vocab_token_index=0`, and
  `vocab_freq_index=2`.

  See `skip_gram_sample()` documentation for more details about the skip-gram
  sampling process.

  Args:
    input_tensor: A rank-1 `Tensor` from which to generate skip-gram candidates.
    vocab_freq_file: `string` specifying full file path to the text vocab file.
    vocab_token_index: `int` specifying which column in the text vocab file
      contains the tokens.
    vocab_token_dtype: `DType` specifying the format of the tokens in the text
      vocab file.
    vocab_freq_index: `int` specifying which column in the text vocab file
      contains the frequency counts of the tokens.
    vocab_freq_dtype: `DType` specifying the format of the frequency counts in
      the text vocab file.
    vocab_delimiter: `string` specifying the delimiter used in the text vocab
      file.
    vocab_min_count: `int`, `float`, or scalar `Tensor` specifying
      minimum frequency threshold (from `vocab_freq_file`) for a token to be
      kept in `input_tensor`. This should correspond with `vocab_freq_dtype`.
    vocab_subsampling: (Optional) `float` specifying frequency proportion
      threshold for tokens from `input_tensor`. Tokens that occur more
      frequently will be randomly down-sampled. Reasonable starting values may
      be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546 for
      more details.
    corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
      total number of tokens in the corpus (e.g., sum of all the frequency
      counts of `vocab_freq_file`). Used with `vocab_subsampling` for
      down-sampling frequently occurring tokens. If this is specified,
      `vocab_freq_file` and `vocab_subsampling` must also be specified.
      If `corpus_size` is needed but not supplied, then it will be calculated
      from `vocab_freq_file`. You might want to supply your own value if you
      have already eliminated infrequent tokens from your vocabulary files
      (where frequency < vocab_min_count) to save memory in the internal token
      lookup table. Otherwise, the unused tokens' variables will waste memory.
      The user-supplied `corpus_size` value must be greater than or equal to
      the sum of all the frequency counts of `vocab_freq_file`.
    min_skips: `int` or scalar `Tensor` specifying the minimum window size to
      randomly use for each token. Must be >= 0 and <= `max_skips`. If
      `min_skips` and `max_skips` are both 0, the only label outputted will be
      the token itself.
    max_skips: `int` or scalar `Tensor` specifying the maximum window size to
      randomly use for each token. Must be >= 0.
    start: `int` or scalar `Tensor` specifying the position in `input_tensor`
      from which to start generating skip-gram candidates.
    limit: `int` or scalar `Tensor` specifying the maximum number of elements in
      `input_tensor` to use in generating skip-gram candidates. -1 means to use
      the rest of the `Tensor` after `start`.
    emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
      each token as a label for itself.
    batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
    batch_capacity: (Optional) `int` specifying batch capacity for the queue
      used for batching returned `Tensors`. Only has an effect if
      `batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
    seed: (Optional) `int` used to create a random seed for window size and
      subsampling. See `set_random_seed` docs for behavior.
    name: (Optional) A `string` name or a name scope for the operations.

  Returns:
    A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
    rank-1 and has the same type as `input_tensor`. The `Tensors` will be of
    length `batch_size`; if `batch_size` is not specified, they will be of
    random length, though they will be in sync with each other as long as they
    are evaluated together.

  Raises:
    ValueError: If `vocab_token_index` or `vocab_freq_index` is less than 0 or
      exceeds the number of columns in `vocab_freq_file`. If `vocab_token_index`
      and `vocab_freq_index` are both set to the same column. If any token in
      `vocab_freq_file` has a negative frequency.
  """

  if vocab_token_index < 0 or vocab_freq_index < 0:
    raise ValueError(
        "vocab_token_index={} and vocab_freq_index={} must both be >= 0.".
        format(vocab_token_index, vocab_freq_index))
  if vocab_token_index == vocab_freq_index:
    raise ValueError(
        "vocab_token_index and vocab_freq_index should be different, but are "
        "both {}.".format(vocab_token_index))

  # Iterates through the vocab file and calculates the number of vocab terms as
  # well as the total corpus size (by summing the frequency counts of all the
  # vocab terms).
  calculated_corpus_size = 0.0
  vocab_size = 0
  with gfile.GFile(vocab_freq_file, mode="r") as f:
    reader = csv.reader(f, delimiter=vocab_delimiter)
    for row in reader:
      if vocab_token_index >= len(row) or vocab_freq_index >= len(row):
        raise ValueError(
            "Row in vocab file only has {} columns, so vocab_token_index={} or "
            "vocab_freq_index={} is out of bounds. Row content: {}".format(
                len(row), vocab_token_index, vocab_freq_index, row))
      vocab_size += 1
      freq = vocab_freq_dtype.as_numpy_dtype(row[vocab_freq_index])
      if freq < 0:
        raise ValueError(
            "Row in vocab file has negative frequency of {}. Row content: {}".
            format(freq, row))
      # Note: tokens whose frequencies are below vocab_min_count will still
      # contribute to the total corpus size used for vocab subsampling.
      calculated_corpus_size += freq

  if not corpus_size:
    corpus_size = calculated_corpus_size
  elif calculated_corpus_size - corpus_size > 1e-6:
    raise ValueError(
        "`corpus_size`={} must be greater than or equal to the sum of all the "
        "frequency counts ({}) of `vocab_freq_file` ({}).".format(
            corpus_size, calculated_corpus_size, vocab_freq_file))

  vocab_freq_table = lookup.HashTable(
      lookup.TextFileInitializer(
          filename=vocab_freq_file,
          key_dtype=vocab_token_dtype,
          key_index=vocab_token_index,
          value_dtype=vocab_freq_dtype,
          value_index=vocab_freq_index,
          vocab_size=vocab_size,
          delimiter=vocab_delimiter),
      # For vocab terms not in vocab file, use a default value of -1.
      default_value=-1)

  return skip_gram_sample(
      input_tensor,
      min_skips=min_skips,
      max_skips=max_skips,
      start=start,
      limit=limit,
      emit_self_as_target=emit_self_as_target,
      vocab_freq_table=vocab_freq_table,
      vocab_min_count=vocab_min_count,
      vocab_subsampling=vocab_subsampling,
      # corpus_size is not used unless vocab_subsampling is specified.
      corpus_size=None if vocab_subsampling is None else corpus_size,
      batch_size=batch_size,
      batch_capacity=batch_capacity,
      seed=seed,
      name=name)
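

# Editorial sketch, not part of the original library: a plausible call to
# skip_gram_sample_with_text_vocab() using the CSV layout described in its
# docstring. The file path, column indices, threshold values, and the helper
# name itself are illustrative assumptions only.
def _skip_gram_text_vocab_example():
  """Samples skip-grams, filtering and subsampling tokens via a vocab CSV."""
  # Assume "/tmp/vocab.csv" holds rows like "bonjour,fr,42": the token is in
  # column 0 and its corpus frequency count is in column 2.
  corpus = ops.convert_to_tensor(["hello", "hola", "bonjour", "hello"])
  tokens, labels = skip_gram_sample_with_text_vocab(
      corpus,
      vocab_freq_file="/tmp/vocab.csv",
      vocab_token_index=0,
      vocab_freq_index=2,
      vocab_min_count=10,  # Drop tokens seen fewer than 10 times.
      vocab_subsampling=1e-3,  # Down-sample very frequent tokens (Eq. 5).
      min_skips=1,
      max_skips=2)
  return tokens, labels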


def _filter_input(input_tensor, vocab_freq_table, vocab_min_count,
                  vocab_subsampling, corpus_size, seed):
  """Filters input tensor based on vocab freq, threshold, and subsampling."""
  if vocab_freq_table is None:
    return input_tensor

  if not isinstance(vocab_freq_table, lookup.InitializableLookupTableBase):
    raise ValueError(
        "vocab_freq_table must be a subclass of "
        "InitializableLookupTableBase (such as HashTable) instead of type "
        "{}.".format(type(vocab_freq_table)))

  with ops.name_scope(
      "filter_vocab", values=[vocab_freq_table, input_tensor, vocab_min_count]):
    freq = vocab_freq_table.lookup(input_tensor)
    # Filters out elements in input_tensor that are not found in
    # vocab_freq_table (the table returns its default value of -1, specified
    # above, when an element is not found).
    mask = math_ops.not_equal(freq, vocab_freq_table.default_value)

    # Filters out elements whose vocab frequencies are less than the threshold.
    if vocab_min_count is not None:
      cast_threshold = math_ops.cast(vocab_min_count, freq.dtype)
      mask = math_ops.logical_and(mask,
                                  math_ops.greater_equal(freq, cast_threshold))

    input_tensor = array_ops.boolean_mask(input_tensor, mask)
    freq = array_ops.boolean_mask(freq, mask)

  if not vocab_subsampling:
    return input_tensor

  if vocab_subsampling < 0 or vocab_subsampling > 1:
    raise ValueError(
        "Invalid vocab_subsampling={} - it should be within range [0, 1].".
        format(vocab_subsampling))

  # Subsamples the input tokens based on vocabulary frequency and
  # vocab_subsampling threshold (i.e., randomly discards commonly appearing
  # tokens).
  with ops.name_scope(
      "subsample_vocab", values=[input_tensor, freq, vocab_subsampling]):
    corpus_size = math_ops.cast(corpus_size, dtypes.float64)
    freq = math_ops.cast(freq, dtypes.float64)
    vocab_subsampling = math_ops.cast(vocab_subsampling, dtypes.float64)

    # From tensorflow_models/tutorials/embedding/word2vec_kernels.cc, which is
    # supposed to correspond to Eq. 5 in http://arxiv.org/abs/1310.4546.
    keep_prob = ((math_ops.sqrt(freq /
                                (vocab_subsampling * corpus_size)) + 1.0) *
                 (vocab_subsampling * corpus_size / freq))
    random_prob = random_ops.random_uniform(
        array_ops.shape(freq),
        minval=0,
        maxval=1,
        dtype=dtypes.float64,
        seed=seed)

    mask = math_ops.less_equal(random_prob, keep_prob)
    return array_ops.boolean_mask(input_tensor, mask)
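

# Editorial note, not part of the original library: the keep probability used
# in _filter_input() above follows the word2vec kernels' form of Eq. 5 in
# http://arxiv.org/abs/1310.4546,
#   keep_prob = (sqrt(freq / (t * N)) + 1) * (t * N) / freq,
# where freq is the token's corpus count, N is corpus_size, and t is
# vocab_subsampling. The plain-Python sketch below restates that arithmetic
# for intuition only; e.g., with N=1e6 and t=1e-3, a token seen 100,000 times
# is kept with probability (10 + 1) * 0.01 = 0.11, while sufficiently rare
# tokens get keep probabilities >= 1.0 and are never dropped.
def _subsample_keep_prob_sketch(freq, corpus_size, vocab_subsampling):
  """Keep probability for a token seen `freq` times (illustrative only)."""
  scaled = vocab_subsampling * corpus_size
  # Values greater than 1.0 simply mean the token is always kept.
  return ((freq / scaled) ** 0.5 + 1.0) * (scaled / freq)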