• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Logistic regression (aka binary classifier) class (deprecated).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20
21This defines some useful basic metrics for using logistic regression to classify
22a binary event (0 vs 1).
23"""
24
25from __future__ import absolute_import
26from __future__ import division
27from __future__ import print_function
28
29from tensorflow.contrib import metrics as metrics_lib
30from tensorflow.contrib.learn.python.learn.estimators import constants
31from tensorflow.contrib.learn.python.learn.estimators import estimator
32from tensorflow.contrib.learn.python.learn.estimators import metric_key
33from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
34from tensorflow.python.framework import dtypes
35from tensorflow.python.ops import math_ops
36
37
def _get_model_fn_with_logistic_metrics(model_fn):
  """Wraps `model_fn` so that evaluation also reports logistic metrics.

  Args:
    model_fn: Model function with the signature:
      `(features, labels, mode) -> (predictions, loss, train_op)`.
      The returned predictions are expected to be probabilities in
      [0.0, 1.0].

  Returns:
    A model_fn taking `(features, labels, mode, params)` and returning a
    `ModelFnOps`, suitable for use with Estimator.
  """

  def _model_fn(features, labels, mode, params):
    """Runs the wrapped `model_fn` and attaches logistic eval metrics."""
    # A missing, `None` or empty 'thresholds' entry falls back to [0.5].
    thresholds = params.get('thresholds') or [.5]

    predictions, loss, train_op = model_fn(features, labels, mode)

    # Metric ops are only built when evaluating; other modes carry none.
    eval_metric_ops = None
    if mode == model_fn_lib.ModeKeys.EVAL:
      eval_metric_ops = _make_logistic_eval_metric_ops(
          labels=labels, predictions=predictions, thresholds=thresholds)

    # Single output alternative, registered under the 'head' key.
    head_alternative = (constants.ProblemType.LOGISTIC_REGRESSION,
                        {'predictions': predictions})
    return model_fn_lib.ModelFnOps(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        output_alternatives={'head': head_alternative})

  return _model_fn
75
76
77# TODO(roumposg): Deprecate and delete after converting users to use head.
def LogisticRegressor(  # pylint: disable=invalid-name
    model_fn, thresholds=None, model_dir=None, config=None,
    feature_engineering_fn=None):
  """Builds a logistic regression Estimator for binary classification.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  The returned Estimator wraps a custom binary classification `model_fn` so
  that evaluation additionally reports AUC, precision/recall and accuracy.

  Example:

  ```python
    # See tf.contrib.learn.Estimator(...) for details on model_fn structure
    def my_model_fn(...):
      pass

    estimator = LogisticRegressor(model_fn=my_model_fn)

    # Input builders
    def input_fn_train():
      pass

    estimator.fit(input_fn=input_fn_train)
    estimator.predict(x=x)
  ```

  Args:
    model_fn: Model function with the signature:
      `(features, labels, mode) -> (predictions, loss, train_op)`.
      Expects the returned predictions to be probabilities in [0.0, 1.0].
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics. If `None`, defaults to `[0.5]`.
    model_dir: Directory to save model parameters, graphs, etc. This can also
      be used to load checkpoints from the directory into an estimator to
      continue training a previously saved model.
    config: A RunConfig configuration object.
    feature_engineering_fn: Feature engineering function. Takes features and
      labels which are the output of `input_fn` and returns features and
      labels which will be fed into the model.

  Returns:
    An `Estimator` instance.
  """
  # Wrap the user's model_fn so the logistic metrics are added in eval mode.
  logistic_model_fn = _get_model_fn_with_logistic_metrics(model_fn)
  # The thresholds reach the wrapped model_fn through the Estimator's params.
  estimator_params = {'thresholds': thresholds}
  return estimator.Estimator(
      model_fn=logistic_model_fn,
      model_dir=model_dir,
      config=config,
      params=estimator_params,
      feature_engineering_fn=feature_engineering_fn)
131
132
def _make_logistic_eval_metric_ops(labels, predictions, thresholds):
  """Returns a dictionary of evaluation metric ops for logistic regression.

  The result always contains the streaming mean of the predictions, the
  streaming mean of the labels (stored under both the label-mean and the
  accuracy-baseline keys) and the streaming AUC. For every entry of
  `thresholds` it also contains accuracy, precision and recall computed on
  the predictions binarized with `predictions >= threshold`.

  Args:
    labels: The labels `Tensor`, or a dict with only one `Tensor` keyed by name.
      A dict with any other number of entries is passed to the metric ops
      unchanged.
    predictions: The predictions `Tensor`.
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics.

  Returns:
    A dict of metric results keyed by name.
  """
  # If labels is a dict with a single key, unpack into a single tensor.
  # `dict.values()` is a non-indexable view on Python 3, so the only value is
  # fetched through an iterator, which works on both Python 2 and Python 3.
  labels_tensor = labels
  if isinstance(labels, dict) and len(labels) == 1:
    labels_tensor = next(iter(labels.values()))

  metrics = {}
  metrics[metric_key.MetricKey.PREDICTION_MEAN] = metrics_lib.streaming_mean(
      predictions)
  metrics[metric_key.MetricKey.LABEL_MEAN] = metrics_lib.streaming_mean(
      labels_tensor)
  # Also include the streaming mean of the label as an accuracy baseline, as
  # a reminder to users.
  metrics[metric_key.MetricKey.ACCURACY_BASELINE] = metrics_lib.streaming_mean(
      labels_tensor)

  metrics[metric_key.MetricKey.AUC] = metrics_lib.streaming_auc(
      labels=labels_tensor, predictions=predictions)

  for threshold in thresholds:
    # Binarize the predictions: 1.0 where prediction >= threshold, else 0.0.
    predictions_at_threshold = math_ops.cast(
        math_ops.greater_equal(predictions, threshold),
        dtypes.float32,
        name='predictions_at_threshold_%f' % threshold)
    # The per-threshold metric keys are format strings filled in with the
    # threshold value.
    metrics[metric_key.MetricKey.ACCURACY_MEAN % threshold] = (
        metrics_lib.streaming_accuracy(labels=labels_tensor,
                                       predictions=predictions_at_threshold))
    # Precision for positive examples.
    metrics[metric_key.MetricKey.PRECISION_MEAN % threshold] = (
        metrics_lib.streaming_precision(labels=labels_tensor,
                                        predictions=predictions_at_threshold))
    # Recall for positive examples.
    metrics[metric_key.MetricKey.RECALL_MEAN % threshold] = (
        metrics_lib.streaming_recall(labels=labels_tensor,
                                     predictions=predictions_at_threshold))

  return metrics
181