1# encoding: utf-8 2# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Text processor tests.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21from __future__ import unicode_literals 22 23from tensorflow.contrib.learn.python.learn.preprocessing import CategoricalVocabulary 24from tensorflow.contrib.learn.python.learn.preprocessing import text 25from tensorflow.python.platform import test 26 27 28class TextTest(test.TestCase): 29 """Text processor tests.""" 30 31 def testTokenizer(self): 32 words = text.tokenizer( 33 ["a b c", "a\nb\nc", "a, b - c", "фыв выф", "你好 怎么样"]) 34 self.assertEqual( 35 list(words), [["a", "b", "c"], ["a", "b", "c"], ["a", "b", "-", "c"], 36 ["фыв", "выф"], ["你好", "怎么样"]]) 37 38 def testByteProcessor(self): 39 processor = text.ByteProcessor(max_document_length=8) 40 inp = ["abc", "фыва", "фыва", b"abc", "12345678901234567890"] 41 res = list(processor.fit_transform(inp)) 42 self.assertAllEqual(res, [[97, 98, 99, 0, 0, 0, 0, 0], 43 [209, 132, 209, 139, 208, 178, 208, 176], 44 [209, 132, 209, 139, 208, 178, 208, 176], 45 [97, 98, 99, 0, 0, 0, 0, 0], 46 [49, 50, 51, 52, 53, 54, 55, 56]]) 47 res = list(processor.reverse(res)) 48 self.assertAllEqual(res, ["abc", "фыва", "фыва", "abc", "12345678"]) 49 50 def testVocabularyProcessor(self): 51 vocab_processor = text.VocabularyProcessor( 52 max_document_length=4, min_frequency=1) 53 tokens = vocab_processor.fit_transform(["a b c", "a\nb\nc", "a, b - c"]) 54 self.assertAllEqual( 55 list(tokens), [[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 3]]) 56 57 def testVocabularyProcessorSaveRestore(self): 58 filename = test.get_temp_dir() + "test.vocab" 59 vocab_processor = text.VocabularyProcessor( 60 max_document_length=4, min_frequency=1) 61 tokens = vocab_processor.fit_transform(["a b c", "a\nb\nc", "a, b - c"]) 62 vocab_processor.save(filename) 63 new_vocab = text.VocabularyProcessor.restore(filename) 64 tokens = new_vocab.transform(["a b c"]) 65 self.assertAllEqual(list(tokens), [[1, 2, 3, 0]]) 66 67 def testExistingVocabularyProcessor(self): 68 vocab = CategoricalVocabulary() 69 vocab.get("A") 70 vocab.get("B") 71 vocab.freeze() 72 vocab_processor = text.VocabularyProcessor( 73 max_document_length=4, vocabulary=vocab, tokenizer_fn=list) 74 tokens = vocab_processor.fit_transform(["ABC", "CBABAF"]) 75 self.assertAllEqual(list(tokens), [[1, 2, 0, 0], [0, 2, 1, 2]]) 76 77 78if __name__ == "__main__": 79 test.main() 80