1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Skip-gram sampling ops tests.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import csv 21import os 22 23from tensorflow.contrib import lookup 24from tensorflow.contrib import text 25from tensorflow.contrib.text.python.ops import skip_gram_ops 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import errors 29from tensorflow.python.framework import random_seed 30from tensorflow.python.ops import lookup_ops 31from tensorflow.python.ops import math_ops 32from tensorflow.python.platform import test 33from tensorflow.python.training import coordinator 34from tensorflow.python.training import queue_runner_impl 35 36 37class SkipGramOpsTest(test.TestCase): 38 39 def _split_tokens_labels(self, output): 40 tokens = [x[0] for x in output] 41 labels = [x[1] for x in output] 42 return tokens, labels 43 44 def test_skip_gram_sample_skips_2(self): 45 """Tests skip-gram with min_skips = max_skips = 2.""" 46 input_tensor = constant_op.constant( 47 [b"the", b"quick", b"brown", b"fox", b"jumps"]) 48 tokens, labels = text.skip_gram_sample( 49 input_tensor, min_skips=2, max_skips=2) 50 expected_tokens, expected_labels = self._split_tokens_labels([ 51 (b"the", b"quick"), 52 (b"the", b"brown"), 53 (b"quick", b"the"), 54 (b"quick", b"brown"), 55 (b"quick", b"fox"), 56 (b"brown", b"the"), 57 (b"brown", b"quick"), 58 (b"brown", b"fox"), 59 (b"brown", b"jumps"), 60 (b"fox", b"quick"), 61 (b"fox", b"brown"), 62 (b"fox", b"jumps"), 63 (b"jumps", b"brown"), 64 (b"jumps", b"fox"), 65 ]) 66 with self.cached_session(): 67 self.assertAllEqual(expected_tokens, tokens.eval()) 68 self.assertAllEqual(expected_labels, labels.eval()) 69 70 def test_skip_gram_sample_emit_self(self): 71 """Tests skip-gram with emit_self_as_target = True.""" 72 input_tensor = constant_op.constant( 73 [b"the", b"quick", b"brown", b"fox", b"jumps"]) 74 tokens, labels = text.skip_gram_sample( 75 input_tensor, min_skips=2, max_skips=2, emit_self_as_target=True) 76 expected_tokens, expected_labels = self._split_tokens_labels([ 77 (b"the", b"the"), 78 (b"the", b"quick"), 79 (b"the", b"brown"), 80 (b"quick", b"the"), 81 (b"quick", b"quick"), 82 (b"quick", b"brown"), 83 (b"quick", b"fox"), 84 (b"brown", b"the"), 85 (b"brown", b"quick"), 86 (b"brown", b"brown"), 87 (b"brown", b"fox"), 88 (b"brown", b"jumps"), 89 (b"fox", b"quick"), 90 (b"fox", b"brown"), 91 (b"fox", b"fox"), 92 (b"fox", b"jumps"), 93 (b"jumps", b"brown"), 94 (b"jumps", b"fox"), 95 (b"jumps", b"jumps"), 96 ]) 97 with self.cached_session(): 98 self.assertAllEqual(expected_tokens, tokens.eval()) 99 self.assertAllEqual(expected_labels, labels.eval()) 100 101 def test_skip_gram_sample_skips_0(self): 102 """Tests skip-gram with min_skips = max_skips = 0.""" 103 input_tensor = constant_op.constant([b"the", b"quick", b"brown"]) 104 105 # If emit_self_as_target is False (default), output will be empty. 106 tokens, labels = text.skip_gram_sample( 107 input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False) 108 with self.cached_session(): 109 self.assertEqual(0, tokens.eval().size) 110 self.assertEqual(0, labels.eval().size) 111 112 # If emit_self_as_target is True, each token will be its own label. 113 tokens, labels = text.skip_gram_sample( 114 input_tensor, min_skips=0, max_skips=0, emit_self_as_target=True) 115 expected_tokens, expected_labels = self._split_tokens_labels([ 116 (b"the", b"the"), 117 (b"quick", b"quick"), 118 (b"brown", b"brown"), 119 ]) 120 with self.cached_session(): 121 self.assertAllEqual(expected_tokens, tokens.eval()) 122 self.assertAllEqual(expected_labels, labels.eval()) 123 124 def test_skip_gram_sample_skips_exceed_length(self): 125 """Tests skip-gram when min/max_skips exceed length of input.""" 126 input_tensor = constant_op.constant([b"the", b"quick", b"brown"]) 127 tokens, labels = text.skip_gram_sample( 128 input_tensor, min_skips=100, max_skips=100) 129 expected_tokens, expected_labels = self._split_tokens_labels([ 130 (b"the", b"quick"), 131 (b"the", b"brown"), 132 (b"quick", b"the"), 133 (b"quick", b"brown"), 134 (b"brown", b"the"), 135 (b"brown", b"quick"), 136 ]) 137 with self.cached_session(): 138 self.assertAllEqual(expected_tokens, tokens.eval()) 139 self.assertAllEqual(expected_labels, labels.eval()) 140 141 def test_skip_gram_sample_start_limit(self): 142 """Tests skip-gram over a limited portion of the input.""" 143 input_tensor = constant_op.constant( 144 [b"foo", b"the", b"quick", b"brown", b"bar"]) 145 tokens, labels = text.skip_gram_sample( 146 input_tensor, min_skips=1, max_skips=1, start=1, limit=3) 147 expected_tokens, expected_labels = self._split_tokens_labels([ 148 (b"the", b"quick"), 149 (b"quick", b"the"), 150 (b"quick", b"brown"), 151 (b"brown", b"quick"), 152 ]) 153 with self.cached_session(): 154 self.assertAllEqual(expected_tokens, tokens.eval()) 155 self.assertAllEqual(expected_labels, labels.eval()) 156 157 def test_skip_gram_sample_limit_exceeds(self): 158 """Tests skip-gram when limit exceeds the length of the input.""" 159 input_tensor = constant_op.constant([b"foo", b"the", b"quick", b"brown"]) 160 tokens, labels = text.skip_gram_sample( 161 input_tensor, min_skips=1, max_skips=1, start=1, limit=100) 162 expected_tokens, expected_labels = self._split_tokens_labels([ 163 (b"the", b"quick"), 164 (b"quick", b"the"), 165 (b"quick", b"brown"), 166 (b"brown", b"quick"), 167 ]) 168 with self.cached_session(): 169 self.assertAllEqual(expected_tokens, tokens.eval()) 170 self.assertAllEqual(expected_labels, labels.eval()) 171 172 def test_skip_gram_sample_random_skips(self): 173 """Tests skip-gram with min_skips != max_skips, with random output.""" 174 # The number of outputs is non-deterministic in this case, so set random 175 # seed to help ensure the outputs remain constant for this test case. 176 random_seed.set_random_seed(42) 177 178 input_tensor = constant_op.constant( 179 [b"the", b"quick", b"brown", b"fox", b"jumps", b"over"]) 180 tokens, labels = text.skip_gram_sample( 181 input_tensor, min_skips=1, max_skips=2, seed=9) 182 expected_tokens, expected_labels = self._split_tokens_labels([ 183 (b"the", b"quick"), 184 (b"the", b"brown"), 185 (b"quick", b"the"), 186 (b"quick", b"brown"), 187 (b"quick", b"fox"), 188 (b"brown", b"the"), 189 (b"brown", b"quick"), 190 (b"brown", b"fox"), 191 (b"brown", b"jumps"), 192 (b"fox", b"brown"), 193 (b"fox", b"jumps"), 194 (b"jumps", b"fox"), 195 (b"jumps", b"over"), 196 (b"over", b"fox"), 197 (b"over", b"jumps"), 198 ]) 199 with self.cached_session() as sess: 200 tokens_eval, labels_eval = sess.run([tokens, labels]) 201 self.assertAllEqual(expected_tokens, tokens_eval) 202 self.assertAllEqual(expected_labels, labels_eval) 203 204 def test_skip_gram_sample_random_skips_default_seed(self): 205 """Tests outputs are still random when no op-level seed is specified.""" 206 # This is needed since tests set a graph-level seed by default. We want to 207 # explicitly avoid setting both graph-level seed and op-level seed, to 208 # simulate behavior under non-test settings when the user doesn't provide a 209 # seed to us. This results in random_seed.get_seed() returning None for both 210 # seeds, forcing the C++ kernel to execute its default seed logic. 211 random_seed.set_random_seed(None) 212 213 # Uses an input tensor with 10 words, with possible skip ranges in [1, 214 # 5]. Thus, the probability that two random samplings would result in the 215 # same outputs is 1/5^10 ~ 1e-7 (aka the probability of this test being 216 # flaky). 217 input_tensor = constant_op.constant([str(x) for x in range(10)]) 218 219 # Do not provide an op-level seed here! 220 tokens_1, labels_1 = text.skip_gram_sample( 221 input_tensor, min_skips=1, max_skips=5) 222 tokens_2, labels_2 = text.skip_gram_sample( 223 input_tensor, min_skips=1, max_skips=5) 224 225 with self.cached_session() as sess: 226 tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run( 227 [tokens_1, labels_1, tokens_2, labels_2]) 228 229 if len(tokens_1_eval) == len(tokens_2_eval): 230 self.assertNotEqual(tokens_1_eval.tolist(), tokens_2_eval.tolist()) 231 if len(labels_1_eval) == len(labels_2_eval): 232 self.assertNotEqual(labels_1_eval.tolist(), labels_2_eval.tolist()) 233 234 def test_skip_gram_sample_batch(self): 235 """Tests skip-gram with batching.""" 236 input_tensor = constant_op.constant([b"the", b"quick", b"brown", b"fox"]) 237 tokens, labels = text.skip_gram_sample( 238 input_tensor, min_skips=1, max_skips=1, batch_size=3) 239 expected_tokens, expected_labels = self._split_tokens_labels([ 240 (b"the", b"quick"), 241 (b"quick", b"the"), 242 (b"quick", b"brown"), 243 (b"brown", b"quick"), 244 (b"brown", b"fox"), 245 (b"fox", b"brown"), 246 ]) 247 with self.cached_session() as sess: 248 coord = coordinator.Coordinator() 249 threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) 250 251 tokens_eval, labels_eval = sess.run([tokens, labels]) 252 self.assertAllEqual(expected_tokens[:3], tokens_eval) 253 self.assertAllEqual(expected_labels[:3], labels_eval) 254 tokens_eval, labels_eval = sess.run([tokens, labels]) 255 self.assertAllEqual(expected_tokens[3:6], tokens_eval) 256 self.assertAllEqual(expected_labels[3:6], labels_eval) 257 258 coord.request_stop() 259 coord.join(threads) 260 261 def test_skip_gram_sample_non_string_input(self): 262 """Tests skip-gram with non-string input.""" 263 input_tensor = constant_op.constant([1, 2, 3], dtype=dtypes.int16) 264 tokens, labels = text.skip_gram_sample( 265 input_tensor, min_skips=1, max_skips=1) 266 expected_tokens, expected_labels = self._split_tokens_labels([ 267 (1, 2), 268 (2, 1), 269 (2, 3), 270 (3, 2), 271 ]) 272 with self.cached_session(): 273 self.assertAllEqual(expected_tokens, tokens.eval()) 274 self.assertAllEqual(expected_labels, labels.eval()) 275 276 def test_skip_gram_sample_errors(self): 277 """Tests various errors raised by skip_gram_sample().""" 278 input_tensor = constant_op.constant([b"the", b"quick", b"brown"]) 279 280 invalid_skips = ( 281 # min_skips and max_skips must be >= 0. 282 (-1, 2), 283 (1, -2), 284 # min_skips must be <= max_skips. 285 (2, 1)) 286 for min_skips, max_skips in invalid_skips: 287 tokens, labels = text.skip_gram_sample( 288 input_tensor, min_skips=min_skips, max_skips=max_skips) 289 with self.cached_session() as sess, self.assertRaises( 290 errors.InvalidArgumentError): 291 sess.run([tokens, labels]) 292 293 # input_tensor must be of rank 1. 294 with self.assertRaises(ValueError): 295 invalid_tensor = constant_op.constant([[b"the"], [b"quick"], [b"brown"]]) 296 text.skip_gram_sample(invalid_tensor) 297 298 # vocab_freq_table must be provided if vocab_min_count, vocab_subsampling, 299 # or corpus_size is specified. 300 dummy_input = constant_op.constant([""]) 301 with self.assertRaises(ValueError): 302 text.skip_gram_sample( 303 dummy_input, vocab_freq_table=None, vocab_min_count=1) 304 with self.assertRaises(ValueError): 305 text.skip_gram_sample( 306 dummy_input, vocab_freq_table=None, vocab_subsampling=1e-5) 307 with self.assertRaises(ValueError): 308 text.skip_gram_sample(dummy_input, vocab_freq_table=None, corpus_size=100) 309 with self.assertRaises(ValueError): 310 text.skip_gram_sample( 311 dummy_input, 312 vocab_freq_table=None, 313 vocab_subsampling=1e-5, 314 corpus_size=100) 315 316 # vocab_subsampling and corpus_size must both be present or absent. 317 dummy_table = lookup.HashTable( 318 lookup.KeyValueTensorInitializer([b"foo"], [10]), -1) 319 with self.assertRaises(ValueError): 320 text.skip_gram_sample( 321 dummy_input, 322 vocab_freq_table=dummy_table, 323 vocab_subsampling=None, 324 corpus_size=100) 325 with self.assertRaises(ValueError): 326 text.skip_gram_sample( 327 dummy_input, 328 vocab_freq_table=dummy_table, 329 vocab_subsampling=1e-5, 330 corpus_size=None) 331 332 def test_filter_input_filter_vocab(self): 333 """Tests input filtering based on vocab frequency table and thresholds.""" 334 input_tensor = constant_op.constant( 335 [b"the", b"answer", b"to", b"life", b"and", b"universe"]) 336 keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"]) 337 values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64) 338 vocab_freq_table = lookup.HashTable( 339 lookup.KeyValueTensorInitializer(keys, values), -1) 340 341 with self.cached_session(): 342 vocab_freq_table.initializer.run() 343 344 # No vocab_freq_table specified - output should be the same as input. 345 no_table_output = skip_gram_ops._filter_input( 346 input_tensor=input_tensor, 347 vocab_freq_table=None, 348 vocab_min_count=None, 349 vocab_subsampling=None, 350 corpus_size=None, 351 seed=None) 352 self.assertAllEqual(input_tensor.eval(), no_table_output.eval()) 353 354 # vocab_freq_table specified, but no vocab_min_count - output should have 355 # filtered out tokens not in the table (b"answer"). 356 table_output = skip_gram_ops._filter_input( 357 input_tensor=input_tensor, 358 vocab_freq_table=vocab_freq_table, 359 vocab_min_count=None, 360 vocab_subsampling=None, 361 corpus_size=None, 362 seed=None) 363 self.assertAllEqual([b"the", b"to", b"life", b"and", b"universe"], 364 table_output.eval()) 365 366 # vocab_freq_table and vocab_min_count specified - output should have 367 # filtered out tokens whose frequencies are below the threshold 368 # (b"and": 0, b"life": 1). 369 threshold_output = skip_gram_ops._filter_input( 370 input_tensor=input_tensor, 371 vocab_freq_table=vocab_freq_table, 372 vocab_min_count=2, 373 vocab_subsampling=None, 374 corpus_size=None, 375 seed=None) 376 self.assertAllEqual([b"the", b"to", b"universe"], threshold_output.eval()) 377 378 def test_filter_input_subsample_vocab(self): 379 """Tests input filtering based on vocab subsampling.""" 380 # The outputs are non-deterministic, so set random seed to help ensure that 381 # the outputs remain constant for testing. 382 random_seed.set_random_seed(42) 383 384 input_tensor = constant_op.constant([ 385 # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57. 386 b"the", 387 b"answer", # Not in vocab. (Always discarded) 388 b"to", # keep_prob = 0.75. 389 b"life", # keep_prob > 1. (Always kept) 390 b"and", # keep_prob = 0.48. 391 b"universe" # Below vocab threshold of 3. (Always discarded) 392 ]) 393 keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"]) 394 values = constant_op.constant([40, 8, 30, 20, 2], dtypes.int64) 395 vocab_freq_table = lookup.HashTable( 396 lookup.KeyValueTensorInitializer(keys, values), -1) 397 398 with self.cached_session(): 399 vocab_freq_table.initializer.run() 400 output = skip_gram_ops._filter_input( 401 input_tensor=input_tensor, 402 vocab_freq_table=vocab_freq_table, 403 vocab_min_count=3, 404 vocab_subsampling=0.05, 405 corpus_size=math_ops.reduce_sum(values), 406 seed=9) 407 self.assertAllEqual([b"the", b"to", b"life", b"and"], output.eval()) 408 409 def _make_text_vocab_freq_file(self): 410 filepath = os.path.join(test.get_temp_dir(), "vocab_freq.txt") 411 with open(filepath, "w") as f: 412 writer = csv.writer(f) 413 writer.writerows([ 414 ["and", 40], 415 ["life", 8], 416 ["the", 30], 417 ["to", 20], 418 ["universe", 2], 419 ]) 420 return filepath 421 422 def _make_text_vocab_float_file(self): 423 filepath = os.path.join(test.get_temp_dir(), "vocab_freq_float.txt") 424 with open(filepath, "w") as f: 425 writer = csv.writer(f) 426 writer.writerows([ 427 ["and", 0.4], 428 ["life", 0.08], 429 ["the", 0.3], 430 ["to", 0.2], 431 ["universe", 0.02], 432 ]) 433 return filepath 434 435 def test_skip_gram_sample_with_text_vocab_filter_vocab(self): 436 """Tests skip-gram sampling with text vocab and freq threshold filtering.""" 437 input_tensor = constant_op.constant([ 438 b"the", 439 b"answer", # Will be filtered before candidate generation. 440 b"to", 441 b"life", 442 b"and", 443 b"universe" # Will be filtered before candidate generation. 444 ]) 445 446 # b"answer" is not in vocab file, and b"universe"'s frequency is below 447 # threshold of 3. 448 vocab_freq_file = self._make_text_vocab_freq_file() 449 450 tokens, labels = text.skip_gram_sample_with_text_vocab( 451 input_tensor=input_tensor, 452 vocab_freq_file=vocab_freq_file, 453 vocab_token_index=0, 454 vocab_freq_index=1, 455 vocab_min_count=3, 456 min_skips=1, 457 max_skips=1) 458 459 expected_tokens, expected_labels = self._split_tokens_labels([ 460 (b"the", b"to"), 461 (b"to", b"the"), 462 (b"to", b"life"), 463 (b"life", b"to"), 464 (b"life", b"and"), 465 (b"and", b"life"), 466 ]) 467 with self.cached_session(): 468 lookup_ops.tables_initializer().run() 469 self.assertAllEqual(expected_tokens, tokens.eval()) 470 self.assertAllEqual(expected_labels, labels.eval()) 471 472 def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, vocab_min_count, 473 vocab_freq_dtype, corpus_size=None): 474 # The outputs are non-deterministic, so set random seed to help ensure that 475 # the outputs remain constant for testing. 476 random_seed.set_random_seed(42) 477 478 input_tensor = constant_op.constant([ 479 # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57. 480 b"the", 481 b"answer", # Not in vocab. (Always discarded) 482 b"to", # keep_prob = 0.75. 483 b"life", # keep_prob > 1. (Always kept) 484 b"and", # keep_prob = 0.48. 485 b"universe" # Below vocab threshold of 3. (Always discarded) 486 ]) 487 # keep_prob calculated from vocab file with relative frequencies of: 488 # and: 40 489 # life: 8 490 # the: 30 491 # to: 20 492 # universe: 2 493 494 tokens, labels = text.skip_gram_sample_with_text_vocab( 495 input_tensor=input_tensor, 496 vocab_freq_file=vocab_freq_file, 497 vocab_token_index=0, 498 vocab_freq_index=1, 499 vocab_freq_dtype=vocab_freq_dtype, 500 vocab_min_count=vocab_min_count, 501 vocab_subsampling=0.05, 502 corpus_size=corpus_size, 503 min_skips=1, 504 max_skips=1, 505 seed=123) 506 507 expected_tokens, expected_labels = self._split_tokens_labels([ 508 (b"the", b"to"), 509 (b"to", b"the"), 510 (b"to", b"life"), 511 (b"life", b"to"), 512 ]) 513 with self.cached_session() as sess: 514 lookup_ops.tables_initializer().run() 515 tokens_eval, labels_eval = sess.run([tokens, labels]) 516 self.assertAllEqual(expected_tokens, tokens_eval) 517 self.assertAllEqual(expected_labels, labels_eval) 518 519 def test_skip_gram_sample_with_text_vocab_subsample_vocab(self): 520 """Tests skip-gram sampling with text vocab and vocab subsampling.""" 521 # Vocab file frequencies 522 # and: 40 523 # life: 8 524 # the: 30 525 # to: 20 526 # universe: 2 527 # 528 # corpus_size for the above vocab is 40+8+30+20+2 = 100. 529 text_vocab_freq_file = self._make_text_vocab_freq_file() 530 self._text_vocab_subsample_vocab_helper( 531 vocab_freq_file=text_vocab_freq_file, 532 vocab_min_count=3, 533 vocab_freq_dtype=dtypes.int64) 534 self._text_vocab_subsample_vocab_helper( 535 vocab_freq_file=text_vocab_freq_file, 536 vocab_min_count=3, 537 vocab_freq_dtype=dtypes.int64, 538 corpus_size=100) 539 540 # The user-supplied corpus_size should not be less than the sum of all 541 # the frequency counts of vocab_freq_file, which is 100. 542 with self.assertRaises(ValueError): 543 self._text_vocab_subsample_vocab_helper( 544 vocab_freq_file=text_vocab_freq_file, 545 vocab_min_count=3, 546 vocab_freq_dtype=dtypes.int64, 547 corpus_size=99) 548 549 def test_skip_gram_sample_with_text_vocab_subsample_vocab_float(self): 550 """Tests skip-gram sampling with text vocab and subsampling with floats.""" 551 # Vocab file frequencies 552 # and: 0.4 553 # life: 0.08 554 # the: 0.3 555 # to: 0.2 556 # universe: 0.02 557 # 558 # corpus_size for the above vocab is 0.4+0.08+0.3+0.2+0.02 = 1. 559 text_vocab_float_file = self._make_text_vocab_float_file() 560 self._text_vocab_subsample_vocab_helper( 561 vocab_freq_file=text_vocab_float_file, 562 vocab_min_count=0.03, 563 vocab_freq_dtype=dtypes.float32) 564 self._text_vocab_subsample_vocab_helper( 565 vocab_freq_file=text_vocab_float_file, 566 vocab_min_count=0.03, 567 vocab_freq_dtype=dtypes.float32, 568 corpus_size=1.0) 569 570 # The user-supplied corpus_size should not be less than the sum of all 571 # the frequency counts of vocab_freq_file, which is 1. 572 with self.assertRaises(ValueError): 573 self._text_vocab_subsample_vocab_helper( 574 vocab_freq_file=text_vocab_float_file, 575 vocab_min_count=0.03, 576 vocab_freq_dtype=dtypes.float32, 577 corpus_size=0.99) 578 579 def test_skip_gram_sample_with_text_vocab_errors(self): 580 """Tests various errors raised by skip_gram_sample_with_text_vocab().""" 581 dummy_input = constant_op.constant([""]) 582 vocab_freq_file = self._make_text_vocab_freq_file() 583 584 invalid_indices = ( 585 # vocab_token_index can't be negative. 586 (-1, 0), 587 # vocab_freq_index can't be negative. 588 (0, -1), 589 # vocab_token_index can't be equal to vocab_freq_index. 590 (0, 0), 591 (1, 1), 592 # vocab_freq_file only has two columns. 593 (0, 2), 594 (2, 0)) 595 596 for vocab_token_index, vocab_freq_index in invalid_indices: 597 with self.assertRaises(ValueError): 598 text.skip_gram_sample_with_text_vocab( 599 input_tensor=dummy_input, 600 vocab_freq_file=vocab_freq_file, 601 vocab_token_index=vocab_token_index, 602 vocab_freq_index=vocab_freq_index) 603 604 605if __name__ == "__main__": 606 test.main() 607