# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Implements a number of text preprocessing utilities (deprecated).

This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import numpy as np
import six

from tensorflow.python.platform import gfile
from tensorflow.python.util.deprecation import deprecated

from .categorical_vocabulary import CategoricalVocabulary  # pylint: disable=g-bad-import-order

try:
  # pylint: disable=g-import-not-at-top
  import cPickle as pickle
except ImportError:
  # pylint: disable=g-import-not-at-top
  import pickle

TOKENIZER_RE = re.compile(r"[A-Z]{2,}(?![a-z])|[A-Z][a-z]+(?=[A-Z])|[\'\w\-]+",
                          re.UNICODE)
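
# For illustration (kept as a comment so importing this module stays
# side-effect free): TOKENIZER_RE splits camel-case runs and keeps
# apostrophes and hyphens inside tokens.
#
#   TOKENIZER_RE.findall("McDonald's fast-food HTTPRequest")
#   # -> ["Mc", "Donald's", "fast-food", "HTTP", "Request"]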


@deprecated(None, 'Please use tensorflow/transform or tf.data.')
def tokenizer(iterator):
  """Tokenizer generator.

  Args:
    iterator: Input iterator with strings.

  Yields:
    Array of tokens for each value in the input.
  """
  for value in iterator:
    yield TOKENIZER_RE.findall(value)
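
# A minimal usage sketch of the tokenizer generator (illustrative comment):
#
#   list(tokenizer(['Hello world', "it's a test"]))
#   # -> [['Hello', 'world'], ["it's", 'a', 'test']]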


@deprecated(None, 'Please use tensorflow/transform or tf.data.')
class ByteProcessor(object):
  """Maps documents into sequences of ids for bytes.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  @deprecated(None, 'Please use tensorflow/transform or tf.data.')
  def __init__(self, max_document_length):
    self.max_document_length = max_document_length

  def fit(self, x):
    """Does nothing. No fitting required."""
    pass

  def fit_transform(self, x):
    """Calls transform."""
    return self.transform(x)

  # pylint: disable=no-self-use
  def reverse(self, x):
    """Reverses output of transform back to text.

    Args:
      x: iterator or matrix of integers. Document representation in bytes.

    Yields:
      utf-8 decoded strings, one for each input document.
    """
    for data in x:
      document = np.trim_zeros(data.astype(np.int8), trim='b').tobytes()
      try:
        yield document.decode('utf-8')
      except UnicodeDecodeError:
        yield ''

  def transform(self, x):
    """Transforms input documents into sequences of ids.

    Args:
      x: iterator or list of input documents.
        Documents can be bytes or unicode strings, which will be encoded as
        utf-8 to map to bytes. Note that in Python 2, str and bytes are the
        same type.

    Yields:
      iterator of byte ids.
    """
    if six.PY3:
      # In Python 3, memoryview replaces the removed buffer builtin.
      buffer_or_memoryview = memoryview
    else:
      buffer_or_memoryview = buffer  # pylint: disable=undefined-variable
    for document in x:
      if isinstance(document, six.text_type):
        document = document.encode('utf-8')
      document_mv = buffer_or_memoryview(document)
      buff = np.frombuffer(document_mv[:self.max_document_length],
                           dtype=np.uint8)
      yield np.pad(buff, (0, self.max_document_length - len(buff)), 'constant')
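
# A round-trip sketch for ByteProcessor (illustrative comment): transform maps
# each document to a fixed-length array of utf-8 byte values, zero-padded, and
# reverse decodes the bytes back into a string:
#
#   processor = ByteProcessor(max_document_length=8)
#   ids = list(processor.fit_transform(['ab']))
#   # ids[0] -> array([97, 98, 0, 0, 0, 0, 0, 0], dtype=uint8)
#   list(processor.reverse(np.array(ids)))
#   # -> ['ab']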


class VocabularyProcessor(object):
  """Maps documents to sequences of word ids.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  @deprecated(None, 'Please use tensorflow/transform or tf.data.')
  def __init__(self,
               max_document_length,
               min_frequency=0,
               vocabulary=None,
               tokenizer_fn=None):
    """Initializes a VocabularyProcessor instance.

    Args:
      max_document_length: Maximum length of documents.
        If documents are longer, they will be trimmed; if shorter, padded.
      min_frequency: Minimum frequency of words in the vocabulary.
      vocabulary: CategoricalVocabulary object.
      tokenizer_fn: Tokenizer function; defaults to the module-level
        `tokenizer` generator.

    Attributes:
      vocabulary_: CategoricalVocabulary object.
    """
    self.max_document_length = max_document_length
    self.min_frequency = min_frequency
    if vocabulary:
      self.vocabulary_ = vocabulary
    else:
      self.vocabulary_ = CategoricalVocabulary()
    if tokenizer_fn:
      self._tokenizer = tokenizer_fn
    else:
      self._tokenizer = tokenizer
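
  # A constructor sketch (illustrative comment; `vp` is a hypothetical name):
  # tokenizer_fn can be any generator with the same contract as the
  # module-level tokenizer, e.g. plain whitespace splitting:
  #
  #   vp = VocabularyProcessor(
  #       max_document_length=10,
  #       tokenizer_fn=lambda docs: (doc.split() for doc in docs))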

  def fit(self, raw_documents, unused_y=None):
    """Learns a vocabulary dictionary of all tokens in the raw documents.

    Args:
      raw_documents: An iterable which yields either str or unicode.
      unused_y: unused; present to match the fit signature of estimators.

    Returns:
      self
    """
    for tokens in self._tokenizer(raw_documents):
      for token in tokens:
        self.vocabulary_.add(token)
    if self.min_frequency > 0:
      self.vocabulary_.trim(self.min_frequency)
    self.vocabulary_.freeze()
    return self

  def fit_transform(self, raw_documents, unused_y=None):
    """Learns the vocabulary dictionary and returns the indices of words.

    Args:
      raw_documents: An iterable which yields either str or unicode. It is
        iterated twice (once by fit and once by transform), so it must be
        re-iterable (e.g. a list, not a one-shot generator).
      unused_y: unused; present to match the fit_transform signature of
        estimators.

    Returns:
      x: iterable, [n_samples, max_document_length]. Word-id matrix.
    """
    self.fit(raw_documents)
    return self.transform(raw_documents)

  def transform(self, raw_documents):
    """Transforms documents to a word-id matrix.

    Converts words to ids using the vocabulary fitted with fit or the one
    provided in the constructor.

    Args:
      raw_documents: An iterable which yields either str or unicode.

    Yields:
      x: iterable, [n_samples, max_document_length]. Word-id matrix.
    """
    for tokens in self._tokenizer(raw_documents):
      word_ids = np.zeros(self.max_document_length, np.int64)
      for idx, token in enumerate(tokens):
        if idx >= self.max_document_length:
          break
        word_ids[idx] = self.vocabulary_.get(token)
      yield word_ids
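
  # A minimal usage sketch (illustrative comment; the exact ids assume the
  # default CategoricalVocabulary, which reserves id 0 for unknown/padding):
  #
  #   vp = VocabularyProcessor(max_document_length=4)
  #   list(vp.fit_transform(['a cat sat', 'a dog']))
  #   # -> [array([1, 2, 3, 0]), array([1, 4, 0, 0])]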

  def reverse(self, documents):
    """Reverses output of vocabulary mapping back to words.

    Args:
      documents: iterable over sequences of class ids.

    Yields:
      Documents mapped back into words, one space-joined string per document.
    """
    for item in documents:
      output = []
      for class_id in item:
        output.append(self.vocabulary_.reverse(class_id))
      yield ' '.join(output)
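
  # Continuing the sketch above (illustrative comment; the token printed for
  # id 0 depends on the vocabulary's unknown token, '<UNK>' by default):
  #
  #   list(vp.reverse([[1, 2, 3, 0]]))
  #   # -> ['a cat sat <UNK>']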

  def save(self, filename):
    """Saves the vocabulary processor into the given file.

    Args:
      filename: Path to output file.
    """
    with gfile.Open(filename, 'wb') as f:
      f.write(pickle.dumps(self))

  @classmethod
  def restore(cls, filename):
    """Restores a vocabulary processor from the given file.

    Args:
      filename: Path to file to load from.

    Returns:
      VocabularyProcessor object.
    """
    with gfile.Open(filename, 'rb') as f:
      return pickle.loads(f.read())
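
# Persistence sketch (illustrative comment; '/tmp/vocab.pkl' is a placeholder
# path): save pickles the whole processor, and restore loads it back.
#
#   vp.save('/tmp/vocab.pkl')
#   vp_restored = VocabularyProcessor.restore('/tmp/vocab.pkl')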