• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Estimators that combine explicit kernel mappings with linear models."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import six
22
23from tensorflow.contrib import layers
24from tensorflow.contrib.kernel_methods.python.mappers import dense_kernel_mapper as dkm
25from tensorflow.contrib.learn.python.learn.estimators import estimator
26from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
27from tensorflow.contrib.learn.python.learn.estimators import linear
28from tensorflow.contrib.learn.python.learn.estimators import prediction_key
29from tensorflow.python.ops import array_ops
30from tensorflow.python.platform import tf_logging as logging
31
# Keys of the `params` dict that `_KernelEstimator` builds and that
# `_kernel_model_fn` reads.
_FEATURE_COLUMNS = "feature_columns"
_KERNEL_MAPPERS = "kernel_mappers"
_OPTIMIZER = "optimizer"
35
36
37def _check_valid_kernel_mappers(kernel_mappers):
38  """Checks that the input kernel_mappers are valid."""
39  if kernel_mappers is None:
40    return True
41  for kernel_mappers_list in six.itervalues(kernel_mappers):
42    for kernel_mapper in kernel_mappers_list:
43      if not isinstance(kernel_mapper, dkm.DenseKernelMapper):
44        return False
45  return True
46
47
48def _check_valid_head(head):
49  """Returns true if the provided head is supported."""
50  if head is None:
51    return False
52  # pylint: disable=protected-access
53  return isinstance(head, head_lib._BinaryLogisticHead) or isinstance(
54      head, head_lib._MultiClassHead)
55  # pylint: enable=protected-access
56
57
58def _update_features_and_columns(features, feature_columns,
59                                 kernel_mappers_dict):
60  """Updates features and feature_columns based on provided kernel mappers.
61
62  Currently supports the update of `RealValuedColumn`s only.
63
64  Args:
65    features: Initial features dict. The key is a `string` (feature column name)
66      and the value is a tensor.
67    feature_columns: Initial iterable containing all the feature columns to be
68      consumed (possibly after being updated) by the model. All items should be
69      instances of classes derived from `FeatureColumn`.
70    kernel_mappers_dict: A dict from feature column (type: _FeatureColumn) to
71      objects inheriting from KernelMapper class.
72
73  Returns:
74    updated features and feature_columns based on provided kernel_mappers_dict.
75  """
76  if kernel_mappers_dict is None:
77    return features, feature_columns
78
79  # First construct new columns and features affected by kernel_mappers_dict.
80  mapped_features = dict()
81  mapped_columns = set()
82  for feature_column in kernel_mappers_dict:
83    column_name = feature_column.name
84    # Currently only mappings over RealValuedColumns are supported.
85    if not isinstance(feature_column, layers.feature_column._RealValuedColumn):  # pylint: disable=protected-access
86      logging.warning(
87          "Updates are currently supported on RealValuedColumns only. Metadata "
88          "for FeatureColumn {} will not be updated.".format(column_name))
89      continue
90    mapped_column_name = column_name + "_MAPPED"
91    # Construct new feature columns based on provided kernel_mappers.
92    column_kernel_mappers = kernel_mappers_dict[feature_column]
93    new_dim = sum(mapper.output_dim for mapper in column_kernel_mappers)
94    mapped_columns.add(
95        layers.feature_column.real_valued_column(mapped_column_name, new_dim))
96
97    # Get mapped features by concatenating mapped tensors (one mapped tensor
98    # per kernel mappers from the list of kernel mappers corresponding to each
99    # feature column).
100    output_tensors = []
101    for kernel_mapper in column_kernel_mappers:
102      output_tensors.append(kernel_mapper.map(features[column_name]))
103    tensor = array_ops.concat(output_tensors, 1)
104    mapped_features[mapped_column_name] = tensor
105
106  # Finally update features dict and feature_columns.
107  features = features.copy()
108  features.update(mapped_features)
109  feature_columns = set(feature_columns)
110  feature_columns.update(mapped_columns)
111
112  return features, feature_columns
113
114
def _kernel_model_fn(features, labels, mode, params, config=None):
  """model_fn for the Estimator using kernel methods.

  Applies the configured kernel mappers to `features` and delegates to the
  linear model_fn using the resulting features and feature columns. The
  `params` dict passed in is left unmodified.

  Args:
    features: `Tensor` or dict of `Tensor` (depends on data passed to `fit`).
    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
      dtype `int32` or `int64` in the range `[0, n_classes)`.
    mode: Defines whether this is training, evaluation or prediction. See
      `ModeKeys`.
    params: A dict of hyperparameters.
      The following hyperparameters are expected:
      * head: A `Head` instance.
      * feature_columns: An iterable containing all the feature columns used by
          the model.
      * optimizer: string, `Optimizer` object, or callable that defines the
          optimizer to use for training. If `None`, will use a FTRL optimizer.
      * kernel_mappers: Dictionary of kernel mappers to be applied to the input
          features before training.
    config: `RunConfig` object to configure the runtime settings.

  Returns:
    A `ModelFnOps` instance.

  Raises:
    ValueError: If mode is not any of the `ModeKeys` (not raised directly
      here; presumably propagated from the linear model_fn).
  """
  feature_columns = params[_FEATURE_COLUMNS]
  kernel_mappers = params[_KERNEL_MAPPERS]

  updated_features, updated_columns = _update_features_and_columns(
      features, feature_columns, kernel_mappers)

  # Pass a shallow copy to the linear model_fn instead of writing the updated
  # columns back into the caller's dict, so that repeated invocations with the
  # same `params` object always start from the original feature columns.
  linear_params = dict(params)
  linear_params[_FEATURE_COLUMNS] = updated_columns

  return linear._linear_model_fn(  # pylint: disable=protected-access
      updated_features, labels, mode, linear_params, config)
150
151
class _KernelEstimator(estimator.Estimator):
  """Base estimator that feeds kernel-mapped features to a linear model."""

  def __init__(self,
               feature_columns=None,
               model_dir=None,
               weight_column_name=None,
               head=None,
               optimizer=None,
               kernel_mappers=None,
               config=None):
    """Validates the arguments and constructs a `_KernelEstimator` object.

    Args:
      feature_columns: Iterable of feature columns; stored in the model_fn
        params (an empty list is stored when it is falsy).
      model_dir: Forwarded to the base `Estimator` constructor.
      weight_column_name: Accepted but not used by this constructor.
      head: The head to use; must be a `_BinaryLogisticHead` or a
        `_MultiClassHead`.
      optimizer: Stored in the model_fn params under the optimizer key.
      kernel_mappers: `None`, or a dict whose values are iterables of
        `DenseKernelMapper` instances.
      config: Forwarded to the base `Estimator` constructor.

    Raises:
      ValueError: if both `feature_columns` and `kernel_mappers` are empty or
        `None`.
      ValueError: if `kernel_mappers` contains an unsupported mapper.
      ValueError: if `head` is `None` or of an unsupported type.
    """
    if not (feature_columns or kernel_mappers):
      raise ValueError(
          "You should set at least one of feature_columns, kernel_mappers.")

    mappers_are_valid = _check_valid_kernel_mappers(kernel_mappers)
    if not mappers_are_valid:
      raise ValueError("Invalid kernel mappers.")

    head_is_supported = _check_valid_head(head)
    if not head_is_supported:
      raise ValueError(
          "head type: {} is not supported. Supported head types: "
          "_BinaryLogisticHead, _MultiClassHead.".format(type(head)))

    # Hyperparameters consumed by `_kernel_model_fn`.
    model_params = {"head": head}
    model_params[_FEATURE_COLUMNS] = feature_columns or []
    model_params[_OPTIMIZER] = optimizer
    model_params[_KERNEL_MAPPERS] = kernel_mappers

    super(_KernelEstimator, self).__init__(
        model_fn=_kernel_model_fn,
        model_dir=model_dir,
        config=config,
        params=model_params)
186
187
class KernelLinearClassifier(_KernelEstimator):
  """Linear classifier that preprocesses features with explicit kernel maps.

  A linear model is trained after the input features have (optionally) been
  transformed into a mapped space by explicit kernel mappings. Thanks to these
  mappings, a classifier that is linear in the mapped (output) space can
  capture non-linearities in the input space.

  Kernel mappers can be supplied for all, or for a subset, of the feature
  columns. The user therefore effectively provides 2 kinds of feature columns:

  * columns passed as elements of feature_columns in the constructor
  * columns that are keys of the kernel_mappers dict.

  A column that appears in feature_columns only is used as is (no mapping). A
  column that is a key of kernel_mappers has the corresponding kernel mappers
  applied to it. A column may appear in both places, in which case it is used
  in both its original and its mapped form. Currently kernel_mappers are
  supported for _RealValuedColumns only.

  Example usage:
  ```
  real_column_a = real_valued_column(name='real_column_a',...)
  sparse_column_b = sparse_column_with_hash_bucket(...)
  kernel_mappers = {real_column_a : [RandomFourierFeatureMapper(...)]}
  optimizer = ...

  # real_column_a is used as a feature in both its initial and its transformed
  # (mapped) form. sparse_column_b is not affected by kernel mappers.
  kernel_classifier = KernelLinearClassifier(
      feature_columns=[real_column_a, sparse_column_b],
      model_dir=...,
      optimizer=optimizer,
      kernel_mappers=kernel_mappers)

  # real_column_a is used as a feature in its transformed (mapped) form only.
  # sparse_column_b is not affected by kernel mappers.
  kernel_classifier = KernelLinearClassifier(
      feature_columns=[sparse_column_b],
      model_dir=...,
      optimizer=optimizer,
      kernel_mappers=kernel_mappers)

  # Input builders
  def train_input_fn():  # returns x, y
    ...
  def eval_input_fn():  # returns x, y
    ...

  kernel_classifier.fit(input_fn=train_input_fn)
  kernel_classifier.evaluate(input_fn=eval_input_fn)
  kernel_classifier.predict(...)
  ```

  The input of `fit` and `evaluate` should provide the following features,
  otherwise there will be a `KeyError`:

  * if `weight_column_name` is not `None`, a feature with
    `key=weight_column_name` whose value is a `Tensor`.
  * for each `column` in `feature_columns`:
    - if `column` is a `SparseColumn`, a feature with `key=column.name`
      whose `value` is a `SparseTensor`.
    - if `column` is a `WeightedSparseColumn`, two features: the first with
      `key` the id column name, the second with `key` the weight column name.
      Both features' `value` must be a `SparseTensor`.
    - if `column` is a `RealValuedColumn`, a feature with `key=column.name`
      whose `value` is a `Tensor`.
  """

  def __init__(self,
               feature_columns=None,
               model_dir=None,
               n_classes=2,
               weight_column_name=None,
               optimizer=None,
               kernel_mappers=None,
               config=None):
    """Construct a `KernelLinearClassifier` estimator object.

    Args:
      feature_columns: An iterable containing all the feature columns used by
        the model. All items in the set should be instances of classes derived
        from `FeatureColumn`.
      model_dir: Directory to save model parameters, graph etc. This can also be
        used to load checkpoints from the directory into an estimator to
        continue training a previously saved model.
      n_classes: number of label classes. Default is binary classification.
        Class labels are integers representing the class index (i.e. values
        from 0 to n_classes-1); arbitrary label values (e.g. string labels)
        must be converted to class indices first.
      weight_column_name: A string defining feature column name representing
        weights. It is used to down weight or boost examples during training. It
        will be multiplied by the loss of the example.
      optimizer: The optimizer used to train the model. If specified, it should
        be an instance of `tf.Optimizer`. If `None`, the Ftrl optimizer is used
        by default.
      kernel_mappers: Dictionary of kernel mappers to be applied to the input
        features before training a (linear) model. Keys are feature columns and
        values are lists of mappers to be applied to the corresponding feature
        column. Currently only _RealValuedColumns are supported and therefore
        all mappers should conform to the `DenseKernelMapper` interface (see
        ./mappers/dense_kernel_mapper.py).
      config: `RunConfig` object to configure the runtime settings.

    Returns:
      A `KernelLinearClassifier` estimator.

    Raises:
      ValueError: if n_classes < 2.
      ValueError: if neither feature_columns nor kernel_mappers are provided.
      ValueError: if mappers provided as kernel_mappers values are invalid.
    """
    # The head carries both the number of classes and the example weights.
    classification_head = head_lib.multi_class_head(
        n_classes=n_classes, weight_column_name=weight_column_name)
    super(KernelLinearClassifier, self).__init__(
        feature_columns=feature_columns,
        model_dir=model_dir,
        weight_column_name=weight_column_name,
        head=classification_head,
        optimizer=optimizer,
        kernel_mappers=kernel_mappers,
        config=config)

  def predict_classes(self, input_fn=None):
    """Runs inference to determine the predicted class per instance.

    Args:
      input_fn: The input function providing features.

    Returns:
      A generator of predicted classes for the features provided by input_fn.
      Each predicted class is represented by its class index (i.e. integer from
      0 to n_classes-1)
    """
    classes_key = prediction_key.PredictionKey.CLASSES
    # Request only the class output, then unwrap it from each prediction.
    per_example_outputs = super(KernelLinearClassifier, self).predict(
        input_fn=input_fn, outputs=[classes_key])
    return (outputs[classes_key] for outputs in per_example_outputs)

  def predict_proba(self, input_fn=None):
    """Runs inference to determine the class probability predictions.

    Args:
      input_fn: The input function providing features.

    Returns:
      A generator of predicted class probabilities for the features provided by
        input_fn.
    """
    probabilities_key = prediction_key.PredictionKey.PROBABILITIES
    # Request only the probabilities, then unwrap them from each prediction.
    per_example_outputs = super(KernelLinearClassifier, self).predict(
        input_fn=input_fn, outputs=[probabilities_key])
    return (outputs[probabilities_key] for outputs in per_example_outputs)
340