# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Skip-gram sampling ops tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import csv
import os

from tensorflow.contrib import lookup
from tensorflow.contrib import text
from tensorflow.contrib.text.python.ops import skip_gram_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner_impl


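# The tests below hard-code the (token, label) pairs that skip_gram_sample()
# is expected to emit for a given window size. As an illustrative aid (not
# part of the original test suite), here is a minimal pure-Python sketch of
# that expectation, assuming each token is paired with every neighbor within
# `skips` positions of it, in left-to-right order. For example,
# _expected_skip_gram_pairs([b"the", b"quick", b"brown", b"fox", b"jumps"], 2)
# reproduces the expected list in test_skip_gram_sample_skips_2.
def _expected_skip_gram_pairs(tokens, skips, emit_self_as_target=False):
  """Sketch of the (token, label) pairs skip-gram sampling should emit."""
  pairs = []
  for i, token in enumerate(tokens):
    window_start = max(0, i - skips)
    window_end = min(len(tokens), i + skips + 1)
    for j in range(window_start, window_end):
      # Pair the center token with each context token in its window, skipping
      # the center itself unless emit_self_as_target is set.
      if j != i or emit_self_as_target:
        pairs.append((token, tokens[j]))
  return pairs

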
class SkipGramOpsTest(test.TestCase):

  def _split_tokens_labels(self, output):
    tokens = [x[0] for x in output]
    labels = [x[1] for x in output]
    return tokens, labels

  def test_skip_gram_sample_skips_2(self):
    """Tests skip-gram with min_skips = max_skips = 2."""
    input_tensor = constant_op.constant(
        [b"the", b"quick", b"brown", b"fox", b"jumps"])
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=2, max_skips=2)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"the", b"brown"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"quick", b"fox"),
        (b"brown", b"the"),
        (b"brown", b"quick"),
        (b"brown", b"fox"),
        (b"brown", b"jumps"),
        (b"fox", b"quick"),
        (b"fox", b"brown"),
        (b"fox", b"jumps"),
        (b"jumps", b"brown"),
        (b"jumps", b"fox"),
    ])
    with self.cached_session():
      self.assertAllEqual(expected_tokens, tokens.eval())
      self.assertAllEqual(expected_labels, labels.eval())

  def test_skip_gram_sample_emit_self(self):
    """Tests skip-gram with emit_self_as_target = True."""
    input_tensor = constant_op.constant(
        [b"the", b"quick", b"brown", b"fox", b"jumps"])
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=2, max_skips=2, emit_self_as_target=True)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"the"),
        (b"the", b"quick"),
        (b"the", b"brown"),
        (b"quick", b"the"),
        (b"quick", b"quick"),
        (b"quick", b"brown"),
        (b"quick", b"fox"),
        (b"brown", b"the"),
        (b"brown", b"quick"),
        (b"brown", b"brown"),
        (b"brown", b"fox"),
        (b"brown", b"jumps"),
        (b"fox", b"quick"),
        (b"fox", b"brown"),
        (b"fox", b"fox"),
        (b"fox", b"jumps"),
        (b"jumps", b"brown"),
        (b"jumps", b"fox"),
        (b"jumps", b"jumps"),
    ])
    with self.cached_session():
      self.assertAllEqual(expected_tokens, tokens.eval())
      self.assertAllEqual(expected_labels, labels.eval())

  def test_skip_gram_sample_skips_0(self):
    """Tests skip-gram with min_skips = max_skips = 0."""
    input_tensor = constant_op.constant([b"the", b"quick", b"brown"])

    # If emit_self_as_target is False (default), output will be empty.
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False)
    with self.cached_session():
      self.assertEqual(0, tokens.eval().size)
      self.assertEqual(0, labels.eval().size)

    # If emit_self_as_target is True, each token will be its own label.
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=0, max_skips=0, emit_self_as_target=True)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"the"),
        (b"quick", b"quick"),
        (b"brown", b"brown"),
    ])
    with self.cached_session():
      self.assertAllEqual(expected_tokens, tokens.eval())
      self.assertAllEqual(expected_labels, labels.eval())

  def test_skip_gram_sample_skips_exceed_length(self):
    """Tests skip-gram when min/max_skips exceed length of input."""
    input_tensor = constant_op.constant([b"the", b"quick", b"brown"])
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=100, max_skips=100)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"the", b"brown"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"brown", b"the"),
        (b"brown", b"quick"),
    ])
    with self.cached_session():
      self.assertAllEqual(expected_tokens, tokens.eval())
      self.assertAllEqual(expected_labels, labels.eval())

  def test_skip_gram_sample_start_limit(self):
    """Tests skip-gram over a limited portion of the input."""
    input_tensor = constant_op.constant(
        [b"foo", b"the", b"quick", b"brown", b"bar"])
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=1, start=1, limit=3)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"brown", b"quick"),
    ])
    with self.cached_session():
      self.assertAllEqual(expected_tokens, tokens.eval())
      self.assertAllEqual(expected_labels, labels.eval())

  def test_skip_gram_sample_limit_exceeds(self):
    """Tests skip-gram when limit exceeds the length of the input."""
    input_tensor = constant_op.constant([b"foo", b"the", b"quick", b"brown"])
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=1, start=1, limit=100)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"brown", b"quick"),
    ])
    with self.cached_session():
      self.assertAllEqual(expected_tokens, tokens.eval())
      self.assertAllEqual(expected_labels, labels.eval())

  def test_skip_gram_sample_random_skips(self):
    """Tests skip-gram with min_skips != max_skips, with random output."""
    # The number of outputs is non-deterministic in this case, so set random
    # seed to help ensure the outputs remain constant for this test case.
    random_seed.set_random_seed(42)

    input_tensor = constant_op.constant(
        [b"the", b"quick", b"brown", b"fox", b"jumps", b"over"])
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=2, seed=9)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"the", b"brown"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"quick", b"fox"),
        (b"brown", b"the"),
        (b"brown", b"quick"),
        (b"brown", b"fox"),
        (b"brown", b"jumps"),
        (b"fox", b"brown"),
        (b"fox", b"jumps"),
        (b"jumps", b"fox"),
        (b"jumps", b"over"),
        (b"over", b"fox"),
        (b"over", b"jumps"),
    ])
    with self.cached_session() as sess:
      tokens_eval, labels_eval = sess.run([tokens, labels])
      self.assertAllEqual(expected_tokens, tokens_eval)
      self.assertAllEqual(expected_labels, labels_eval)

  def test_skip_gram_sample_random_skips_default_seed(self):
    """Tests outputs are still random when no op-level seed is specified."""
    # This is needed since tests set a graph-level seed by default. We want to
    # explicitly avoid setting both graph-level seed and op-level seed, to
    # simulate behavior under non-test settings when the user doesn't provide a
    # seed to us. This results in random_seed.get_seed() returning None for both
    # seeds, forcing the C++ kernel to execute its default seed logic.
    random_seed.set_random_seed(None)

    # Uses an input tensor with 10 words, each of which draws its number of
    # skips uniformly from [1, 5]. Thus, the probability that two independent
    # samplings produce identical outputs is (1/5)^10 ~ 1e-7 (i.e., the
    # probability of this test being flaky).
    input_tensor = constant_op.constant([str(x) for x in range(10)])

    # Do not provide an op-level seed here!
    tokens_1, labels_1 = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=5)
    tokens_2, labels_2 = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=5)

    with self.cached_session() as sess:
      tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run(
          [tokens_1, labels_1, tokens_2, labels_2])

    if len(tokens_1_eval) == len(tokens_2_eval):
      self.assertNotEqual(tokens_1_eval.tolist(), tokens_2_eval.tolist())
    if len(labels_1_eval) == len(labels_2_eval):
      self.assertNotEqual(labels_1_eval.tolist(), labels_2_eval.tolist())

  def test_skip_gram_sample_batch(self):
    """Tests skip-gram with batching."""
    input_tensor = constant_op.constant([b"the", b"quick", b"brown", b"fox"])
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=1, batch_size=3)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"brown", b"quick"),
        (b"brown", b"fox"),
        (b"fox", b"brown"),
    ])
    with self.cached_session() as sess:
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)

      tokens_eval, labels_eval = sess.run([tokens, labels])
      self.assertAllEqual(expected_tokens[:3], tokens_eval)
      self.assertAllEqual(expected_labels[:3], labels_eval)
      tokens_eval, labels_eval = sess.run([tokens, labels])
      self.assertAllEqual(expected_tokens[3:6], tokens_eval)
      self.assertAllEqual(expected_labels[3:6], labels_eval)

      coord.request_stop()
      coord.join(threads)

  def test_skip_gram_sample_non_string_input(self):
    """Tests skip-gram with non-string input."""
    input_tensor = constant_op.constant([1, 2, 3], dtype=dtypes.int16)
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=1)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (1, 2),
        (2, 1),
        (2, 3),
        (3, 2),
    ])
    with self.cached_session():
      self.assertAllEqual(expected_tokens, tokens.eval())
      self.assertAllEqual(expected_labels, labels.eval())

  def test_skip_gram_sample_errors(self):
    """Tests various errors raised by skip_gram_sample()."""
    input_tensor = constant_op.constant([b"the", b"quick", b"brown"])

    invalid_skips = (
        # min_skips and max_skips must be >= 0.
        (-1, 2),
        (1, -2),
        # min_skips must be <= max_skips.
        (2, 1))
    for min_skips, max_skips in invalid_skips:
      tokens, labels = text.skip_gram_sample(
          input_tensor, min_skips=min_skips, max_skips=max_skips)
      with self.cached_session() as sess, self.assertRaises(
          errors.InvalidArgumentError):
        sess.run([tokens, labels])

    # input_tensor must be of rank 1.
    with self.assertRaises(ValueError):
      invalid_tensor = constant_op.constant([[b"the"], [b"quick"], [b"brown"]])
      text.skip_gram_sample(invalid_tensor)

    # vocab_freq_table must be provided if vocab_min_count, vocab_subsampling,
    # or corpus_size is specified.
    dummy_input = constant_op.constant([""])
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input, vocab_freq_table=None, vocab_min_count=1)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input, vocab_freq_table=None, vocab_subsampling=1e-5)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(dummy_input, vocab_freq_table=None, corpus_size=100)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input,
          vocab_freq_table=None,
          vocab_subsampling=1e-5,
          corpus_size=100)

    # vocab_subsampling and corpus_size must both be present or absent.
    dummy_table = lookup.HashTable(
        lookup.KeyValueTensorInitializer([b"foo"], [10]), -1)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input,
          vocab_freq_table=dummy_table,
          vocab_subsampling=None,
          corpus_size=100)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input,
          vocab_freq_table=dummy_table,
          vocab_subsampling=1e-5,
          corpus_size=None)

  def test_filter_input_filter_vocab(self):
    """Tests input filtering based on vocab frequency table and thresholds."""
    input_tensor = constant_op.constant(
        [b"the", b"answer", b"to", b"life", b"and", b"universe"])
    keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"])
    values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
    vocab_freq_table = lookup.HashTable(
        lookup.KeyValueTensorInitializer(keys, values), -1)

    with self.cached_session():
      vocab_freq_table.initializer.run()

      # No vocab_freq_table specified - output should be the same as input.
      no_table_output = skip_gram_ops._filter_input(
          input_tensor=input_tensor,
          vocab_freq_table=None,
          vocab_min_count=None,
          vocab_subsampling=None,
          corpus_size=None,
          seed=None)
      self.assertAllEqual(input_tensor.eval(), no_table_output.eval())

      # vocab_freq_table specified, but no vocab_min_count - output should have
      # filtered out tokens not in the table (b"answer").
      table_output = skip_gram_ops._filter_input(
          input_tensor=input_tensor,
          vocab_freq_table=vocab_freq_table,
          vocab_min_count=None,
          vocab_subsampling=None,
          corpus_size=None,
          seed=None)
      self.assertAllEqual([b"the", b"to", b"life", b"and", b"universe"],
                          table_output.eval())

      # vocab_freq_table and vocab_min_count specified - output should have
      # filtered out tokens whose frequencies are below the threshold
      # (b"and": 0, b"life": 1).
      threshold_output = skip_gram_ops._filter_input(
          input_tensor=input_tensor,
          vocab_freq_table=vocab_freq_table,
          vocab_min_count=2,
          vocab_subsampling=None,
          corpus_size=None,
          seed=None)
      self.assertAllEqual([b"the", b"to", b"universe"], threshold_output.eval())

  def test_filter_input_subsample_vocab(self):
    """Tests input filtering based on vocab subsampling."""
    # The outputs are non-deterministic, so set random seed to help ensure that
    # the outputs remain constant for testing.
    random_seed.set_random_seed(42)

    input_tensor = constant_op.constant([
        # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
        b"the",
        b"answer",  # Not in vocab. (Always discarded)
        b"to",  # keep_prob = 0.75.
        b"life",  # keep_prob > 1. (Always kept)
        b"and",  # keep_prob = 0.48.
        b"universe"  # Below vocab threshold of 3. (Always discarded)
    ])
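    # The keep_prob values noted above follow the word2vec-style frequency
    # subsampling rule (a sketch, assuming this matches the kernel's formula;
    # the values quoted in the comments are consistent with it):
    #   keep_prob = (sqrt(freq / (t * corpus_size)) + 1)
    #               * (t * corpus_size / freq)
    # with t = vocab_subsampling = 0.05 and corpus_size = 100 (the sum of the
    # frequency values below). E.g., b"to": (sqrt(20 / 5) + 1) * (5 / 20) =
    # 0.75, and b"and": (sqrt(40 / 5) + 1) * (5 / 40) ~= 0.48.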
    keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"])
    values = constant_op.constant([40, 8, 30, 20, 2], dtypes.int64)
    vocab_freq_table = lookup.HashTable(
        lookup.KeyValueTensorInitializer(keys, values), -1)

    with self.cached_session():
      vocab_freq_table.initializer.run()
      output = skip_gram_ops._filter_input(
          input_tensor=input_tensor,
          vocab_freq_table=vocab_freq_table,
          vocab_min_count=3,
          vocab_subsampling=0.05,
          corpus_size=math_ops.reduce_sum(values),
          seed=9)
      self.assertAllEqual([b"the", b"to", b"life", b"and"], output.eval())

  def _make_text_vocab_freq_file(self):
    filepath = os.path.join(test.get_temp_dir(), "vocab_freq.txt")
    with open(filepath, "w") as f:
      writer = csv.writer(f)
      writer.writerows([
          ["and", 40],
          ["life", 8],
          ["the", 30],
          ["to", 20],
          ["universe", 2],
      ])
    return filepath

  def _make_text_vocab_float_file(self):
    filepath = os.path.join(test.get_temp_dir(), "vocab_freq_float.txt")
    with open(filepath, "w") as f:
      writer = csv.writer(f)
      writer.writerows([
          ["and", 0.4],
          ["life", 0.08],
          ["the", 0.3],
          ["to", 0.2],
          ["universe", 0.02],
      ])
    return filepath

  def test_skip_gram_sample_with_text_vocab_filter_vocab(self):
    """Tests skip-gram sampling with text vocab and freq threshold filtering."""
    input_tensor = constant_op.constant([
        b"the",
        b"answer",  # Will be filtered before candidate generation.
        b"to",
        b"life",
        b"and",
        b"universe"  # Will be filtered before candidate generation.
    ])

    # b"answer" is not in vocab file, and b"universe"'s frequency is below
    # threshold of 3.
    vocab_freq_file = self._make_text_vocab_freq_file()

    tokens, labels = text.skip_gram_sample_with_text_vocab(
        input_tensor=input_tensor,
        vocab_freq_file=vocab_freq_file,
        vocab_token_index=0,
        vocab_freq_index=1,
        vocab_min_count=3,
        min_skips=1,
        max_skips=1)

    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"to"),
        (b"to", b"the"),
        (b"to", b"life"),
        (b"life", b"to"),
        (b"life", b"and"),
        (b"and", b"life"),
    ])
    with self.cached_session():
      lookup_ops.tables_initializer().run()
      self.assertAllEqual(expected_tokens, tokens.eval())
      self.assertAllEqual(expected_labels, labels.eval())

  def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, vocab_min_count,
                                         vocab_freq_dtype, corpus_size=None):
    # The outputs are non-deterministic, so set random seed to help ensure that
    # the outputs remain constant for testing.
    random_seed.set_random_seed(42)

    input_tensor = constant_op.constant([
        # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
        b"the",
        b"answer",  # Not in vocab. (Always discarded)
        b"to",  # keep_prob = 0.75.
        b"life",  # keep_prob > 1. (Always kept)
        b"and",  # keep_prob = 0.48.
        b"universe"  # Below vocab threshold of 3. (Always discarded)
    ])
    # keep_prob is calculated from a vocab file with these frequencies (raw
    # counts for the int64 variant; the float variant scales them down
    # proportionally, which leaves the freq/corpus_size ratios, and hence the
    # keep_prob values, unchanged):
    # and: 40
    # life: 8
    # the: 30
    # to: 20
    # universe: 2

    tokens, labels = text.skip_gram_sample_with_text_vocab(
        input_tensor=input_tensor,
        vocab_freq_file=vocab_freq_file,
        vocab_token_index=0,
        vocab_freq_index=1,
        vocab_freq_dtype=vocab_freq_dtype,
        vocab_min_count=vocab_min_count,
        vocab_subsampling=0.05,
        corpus_size=corpus_size,
        min_skips=1,
        max_skips=1,
        seed=123)

    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"to"),
        (b"to", b"the"),
        (b"to", b"life"),
        (b"life", b"to"),
    ])
    with self.cached_session() as sess:
      lookup_ops.tables_initializer().run()
      tokens_eval, labels_eval = sess.run([tokens, labels])
      self.assertAllEqual(expected_tokens, tokens_eval)
      self.assertAllEqual(expected_labels, labels_eval)

  def test_skip_gram_sample_with_text_vocab_subsample_vocab(self):
    """Tests skip-gram sampling with text vocab and vocab subsampling."""
    # Vocab file frequencies
    # and: 40
    # life: 8
    # the: 30
    # to: 20
    # universe: 2
    #
    # corpus_size for the above vocab is 40+8+30+20+2 = 100.
    text_vocab_freq_file = self._make_text_vocab_freq_file()
    self._text_vocab_subsample_vocab_helper(
        vocab_freq_file=text_vocab_freq_file,
        vocab_min_count=3,
        vocab_freq_dtype=dtypes.int64)
    self._text_vocab_subsample_vocab_helper(
        vocab_freq_file=text_vocab_freq_file,
        vocab_min_count=3,
        vocab_freq_dtype=dtypes.int64,
        corpus_size=100)

    # The user-supplied corpus_size should not be less than the sum of all
    # the frequency counts of vocab_freq_file, which is 100.
    with self.assertRaises(ValueError):
      self._text_vocab_subsample_vocab_helper(
          vocab_freq_file=text_vocab_freq_file,
          vocab_min_count=3,
          vocab_freq_dtype=dtypes.int64,
          corpus_size=99)

  def test_skip_gram_sample_with_text_vocab_subsample_vocab_float(self):
    """Tests skip-gram sampling with text vocab and subsampling with floats."""
    # Vocab file frequencies
    # and: 0.4
    # life: 0.08
    # the: 0.3
    # to: 0.2
    # universe: 0.02
    #
    # corpus_size for the above vocab is 0.4+0.08+0.3+0.2+0.02 = 1.
    text_vocab_float_file = self._make_text_vocab_float_file()
    self._text_vocab_subsample_vocab_helper(
        vocab_freq_file=text_vocab_float_file,
        vocab_min_count=0.03,
        vocab_freq_dtype=dtypes.float32)
    self._text_vocab_subsample_vocab_helper(
        vocab_freq_file=text_vocab_float_file,
        vocab_min_count=0.03,
        vocab_freq_dtype=dtypes.float32,
        corpus_size=1.0)

    # The user-supplied corpus_size should not be less than the sum of all
    # the frequency counts of vocab_freq_file, which is 1.
    with self.assertRaises(ValueError):
      self._text_vocab_subsample_vocab_helper(
          vocab_freq_file=text_vocab_float_file,
          vocab_min_count=0.03,
          vocab_freq_dtype=dtypes.float32,
          corpus_size=0.99)

  def test_skip_gram_sample_with_text_vocab_errors(self):
    """Tests various errors raised by skip_gram_sample_with_text_vocab()."""
    dummy_input = constant_op.constant([""])
    vocab_freq_file = self._make_text_vocab_freq_file()

    invalid_indices = (
        # vocab_token_index can't be negative.
        (-1, 0),
        # vocab_freq_index can't be negative.
        (0, -1),
        # vocab_token_index can't be equal to vocab_freq_index.
        (0, 0),
        (1, 1),
        # vocab_freq_file only has two columns.
        (0, 2),
        (2, 0))

    for vocab_token_index, vocab_freq_index in invalid_indices:
      with self.assertRaises(ValueError):
        text.skip_gram_sample_with_text_vocab(
            input_tensor=dummy_input,
            vocab_freq_file=vocab_freq_file,
            vocab_token_index=vocab_token_index,
            vocab_freq_index=vocab_freq_index)


if __name__ == "__main__":
  test.main()