# encoding: utf-8
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Text processor tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os

from tensorflow.contrib.learn.python.learn.preprocessing import CategoricalVocabulary
from tensorflow.contrib.learn.python.learn.preprocessing import text
from tensorflow.python.platform import test
26
27
28class TextTest(test.TestCase):
29  """Text processor tests."""
30
31  def testTokenizer(self):
32    words = text.tokenizer(
33        ["a b c", "a\nb\nc", "a, b - c", "фыв выф", "你好 怎么样"])
34    self.assertEqual(
35        list(words), [["a", "b", "c"], ["a", "b", "c"], ["a", "b", "-", "c"],
36                      ["фыв", "выф"], ["你好", "怎么样"]])
37
38  def testByteProcessor(self):
39    processor = text.ByteProcessor(max_document_length=8)
40    inp = ["abc", "фыва", "фыва", b"abc", "12345678901234567890"]
41    res = list(processor.fit_transform(inp))
42    self.assertAllEqual(res, [[97, 98, 99, 0, 0, 0, 0, 0],
43                              [209, 132, 209, 139, 208, 178, 208, 176],
44                              [209, 132, 209, 139, 208, 178, 208, 176],
45                              [97, 98, 99, 0, 0, 0, 0, 0],
46                              [49, 50, 51, 52, 53, 54, 55, 56]])
47    res = list(processor.reverse(res))
48    self.assertAllEqual(res, ["abc", "фыва", "фыва", "abc", "12345678"])
49
50  def testVocabularyProcessor(self):
51    vocab_processor = text.VocabularyProcessor(
52        max_document_length=4, min_frequency=1)
53    tokens = vocab_processor.fit_transform(["a b c", "a\nb\nc", "a, b - c"])
54    self.assertAllEqual(
55        list(tokens), [[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 3]])
56
57  def testVocabularyProcessorSaveRestore(self):
58    filename = test.get_temp_dir() + "test.vocab"
59    vocab_processor = text.VocabularyProcessor(
60        max_document_length=4, min_frequency=1)
61    tokens = vocab_processor.fit_transform(["a b c", "a\nb\nc", "a, b - c"])
62    vocab_processor.save(filename)
63    new_vocab = text.VocabularyProcessor.restore(filename)
64    tokens = new_vocab.transform(["a b c"])
65    self.assertAllEqual(list(tokens), [[1, 2, 3, 0]])
66
67  def testExistingVocabularyProcessor(self):
68    vocab = CategoricalVocabulary()
69    vocab.get("A")
70    vocab.get("B")
71    vocab.freeze()
72    vocab_processor = text.VocabularyProcessor(
73        max_document_length=4, vocabulary=vocab, tokenizer_fn=list)
74    tokens = vocab_processor.fit_transform(["ABC", "CBABAF"])
75    self.assertAllEqual(list(tokens), [[1, 2, 0, 0], [0, 2, 1, 2]])
76
77
78if __name__ == "__main__":
79  test.main()
80