• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Implements preprocessing transformers for categorical variables (deprecated).
17
18This module and all its submodules are deprecated. See
19[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
20for migration instructions.
21"""
22
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27import math
28import numpy as np
29
30from tensorflow.python.util.deprecation import deprecated
31
32# pylint: disable=g-bad-import-order
33from . import categorical_vocabulary
34from ..learn_io.data_feeder import setup_processor_data_feeder
35# pylint: enable=g-bad-import-order
36
37
class CategoricalProcessor(object):
  """Maps the categorical values in each column of the input to integer ids.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Each column of the input gets its own vocabulary, unless `share=True`, in
  which case all columns use one vocabulary object.

  As a common convention, NaN values are handled as unknown tokens: `fit`
  skips them and `transform` maps them to id 0. Python floats (including
  float('nan') and np.nan) and numpy floating scalars (e.g. np.float32) are
  all recognized as NaN.
  """

  @deprecated(None, 'Please use tensorflow/transform or tf.data for sequence '
              'processing.')
  def __init__(self, min_frequency=0, share=False, vocabularies=None):
    """Initializes a CategoricalProcessor instance.

    Args:
      min_frequency: Minimum frequency of categories in the vocabulary. When
        greater than 0, `fit` trims every vocabulary with this threshold.
      share: Share vocabulary between variables.
      vocabularies: list of CategoricalVocabulary objects for each variable in
        the input dataset. If None, `fit` creates them from the first row.

    Attributes:
      vocabularies_: list of CategoricalVocabulary objects, one entry per
        column (the same object repeated when shared), or None until `fit`
        creates them.
    """
    self.min_frequency = min_frequency
    self.share = share
    self.vocabularies_ = vocabularies

  @staticmethod
  def _is_nan(value):
    """Returns True if `value` is a floating-point NaN scalar.

    `np.float64` subclasses Python `float`, but other numpy floating scalars
    (e.g. `np.float32`) do not, so both types are checked explicitly. A
    comparison such as `value == np.nan` cannot be used because NaN compares
    unequal to everything, including itself.
    """
    return isinstance(value, (float, np.floating)) and math.isnan(value)

  def _unique_vocabularies(self):
    """Returns the distinct vocabulary objects in `vocabularies_`, in order.

    With `share=True` every column references the same vocabulary object.
    Deduplicating by identity makes per-vocabulary operations (trim, freeze)
    run once per vocabulary rather than once per column.

    Returns:
      A list of distinct vocabulary objects; empty if `vocabularies_` is None.
    """
    if self.vocabularies_ is None:
      return []
    seen_ids = set()
    unique = []
    for vocab in self.vocabularies_:
      if id(vocab) not in seen_ids:
        seen_ids.add(id(vocab))
        unique.append(vocab)
    return unique

  def freeze(self, freeze=True):
    """Freeze or unfreeze all vocabularies.

    Does nothing when no vocabularies exist yet (for example before `fit`,
    or after fitting on empty input).

    Args:
      freeze: Boolean, indicate if vocabularies should be frozen.
    """
    for vocab in self._unique_vocabularies():
      vocab.freeze(freeze)

  def fit(self, x, unused_y=None):
    """Learn a vocabulary dictionary of all categories in `x`.

    If no vocabularies were supplied, they are created from the length of the
    first row. NaN values are skipped. After counting, vocabularies are
    trimmed to `min_frequency` (when > 0) and frozen.

    Args:
      x: numpy matrix or iterable of lists/numpy arrays.
      unused_y: to match fit format signature of estimators.

    Returns:
      self
    """
    x = setup_processor_data_feeder(x)
    for row in x:
      # Create vocabularies if not given.
      if self.vocabularies_ is None:
        if self.share:
          # A single vocabulary object referenced by every column.
          shared_vocab = categorical_vocabulary.CategoricalVocabulary()
          self.vocabularies_ = [shared_vocab for _ in row]
        else:
          # One independent vocabulary per column.
          self.vocabularies_ = [
              categorical_vocabulary.CategoricalVocabulary() for _ in row
          ]
      for idx, value in enumerate(row):
        # NaNs are handled as unknowns and never enter the vocabulary.
        if self._is_nan(value):
          continue
        self.vocabularies_[idx].add(value)
    if self.min_frequency > 0:
      # Trim each distinct vocabulary once, even when shared across columns.
      for vocab in self._unique_vocabularies():
        vocab.trim(self.min_frequency)
    self.freeze()
    return self

  def fit_transform(self, x, unused_y=None):
    """Learn the vocabulary dictionary and return indices of categories.

    NOTE: `x` is passed to both `fit` and `transform`. If it is a one-shot
    iterator, `transform` may find it already exhausted, so pass a
    re-iterable input (e.g. a list or numpy array).

    Args:
      x: numpy matrix or iterable of lists/numpy arrays.
      unused_y: to match fit_transform signature of estimators.

    Returns:
      x: iterable, [n_samples]. Category-id matrix.
    """
    self.fit(x)
    return self.transform(x)

  def transform(self, x):
    """Transform documents to category-id matrix.

    Converts categories to ids given the vocabulary fitted by `fit` or the
    one provided in the constructor. This is a generator: the vocabularies
    are frozen, and rows converted, only as the result is iterated.

    Args:
      x: numpy matrix or iterable of lists/numpy arrays.

    Yields:
      x: iterable, [n_samples]. Category-id matrix; each row is an int64
        numpy array.
    """
    self.freeze()
    x = setup_processor_data_feeder(x)
    for row in x:
      output_row = []
      for idx, value in enumerate(row):
        if self._is_nan(value):
          # NaN is reported as id 0, the unknown (<UNK>) id.
          output_row.append(0)
        else:
          output_row.append(self.vocabularies_[idx].get(value))
      yield np.array(output_row, dtype=np.int64)
145