# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logistic regression (aka binary classifier) class (deprecated).

This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.

This defines some useful basic metrics for using logistic regression to classify
a binary event (0 vs 1).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import metric_key
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops


def _get_model_fn_with_logistic_metrics(model_fn):
  """Returns a model_fn with additional logistic metrics.

  The returned function calls the wrapped `model_fn` and, in EVAL mode only,
  attaches the metric ops built by `_make_logistic_eval_metric_ops`.

  Args:
    model_fn: Model function with the signature:
      `(features, labels, mode) -> (predictions, loss, train_op)`.
      Expects the returned predictions to be probabilities in [0.0, 1.0].

  Returns:
    model_fn that can be used with Estimator.
  """

  def _model_fn(features, labels, mode, params):
    """Model function that appends logistic evaluation metrics.

    Args:
      features: Forwarded unchanged to the wrapped `model_fn`.
      labels: Forwarded unchanged to the wrapped `model_fn`; also used to
        build the evaluation metrics in EVAL mode.
      mode: Forwarded to the wrapped `model_fn` and compared against
        `ModeKeys.EVAL` to decide whether metrics are built.
      params: Dict-like object; the optional `'thresholds'` entry lists the
        classification thresholds used for the per-threshold metrics.

    Returns:
      A `ModelFnOps` wrapping the outputs of the user-supplied `model_fn`.
    """
    # A missing, `None` or empty thresholds entry falls back to a single 0.5.
    thresholds = params.get('thresholds') or [.5]

    predictions, loss, train_op = model_fn(features, labels, mode)
    if mode == model_fn_lib.ModeKeys.EVAL:
      eval_metric_ops = _make_logistic_eval_metric_ops(
          labels=labels,
          predictions=predictions,
          thresholds=thresholds)
    else:
      eval_metric_ops = None
    return model_fn_lib.ModelFnOps(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        # Expose the predictions under a single 'head' output alternative
        # tagged as a logistic regression problem.
        output_alternatives={
            'head': (constants.ProblemType.LOGISTIC_REGRESSION, {
                'predictions': predictions
            })
        })

  return _model_fn


# TODO(roumposg): Deprecate and delete after converting users to use head.
def LogisticRegressor(  # pylint: disable=invalid-name
    model_fn, thresholds=None, model_dir=None, config=None,
    feature_engineering_fn=None):
  """Builds a logistic regression Estimator for binary classification.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  This method provides a basic Estimator with some additional metrics for custom
  binary classification models, including AUC, precision/recall and accuracy.

  Example:

  ```python
    # See tf.contrib.learn.Estimator(...) for details on model_fn structure
    def my_model_fn(...):
      pass

    estimator = LogisticRegressor(model_fn=my_model_fn)

    # Input builders
    def input_fn_train():
      pass

    estimator.fit(input_fn=input_fn_train)
    estimator.predict(x=x)
  ```

  Args:
    model_fn: Model function with the signature:
      `(features, labels, mode) -> (predictions, loss, train_op)`.
      Expects the returned predictions to be probabilities in [0.0, 1.0].
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics. If `None`, defaults to `[0.5]`.
    model_dir: Directory to save model parameters, graphs, etc. This can also
      be used to load checkpoints from the directory into an estimator to
      continue training a previously saved model.
    config: A RunConfig configuration object.
    feature_engineering_fn: Feature engineering function. Takes features and
                      labels which are the output of `input_fn` and
                      returns features and labels which will be fed
                      into the model.

  Returns:
    An `Estimator` instance.
  """
  return estimator.Estimator(
      model_fn=_get_model_fn_with_logistic_metrics(model_fn),
      model_dir=model_dir,
      config=config,
      # Thresholds travel through `params` so the wrapped model_fn can read
      # them when it builds the evaluation metrics.
      params={'thresholds': thresholds},
      feature_engineering_fn=feature_engineering_fn)


def _make_logistic_eval_metric_ops(labels, predictions, thresholds):
  """Returns a dictionary of evaluation metric ops for logistic regression.

  Args:
    labels: The labels `Tensor`, or a dict with only one `Tensor` keyed by name.
    predictions: The predictions `Tensor`.
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics.

  Returns:
    A dict of metric results keyed by name.
  """
  # If labels is a dict with a single key, unpack into a single tensor.
  # `next(iter(...))` is used instead of `.values()[0]` because dict views are
  # not subscriptable on Python 3.
  labels_tensor = labels
  if isinstance(labels, dict) and len(labels) == 1:
    labels_tensor = next(iter(labels.values()))

  metrics = {}
  metrics[metric_key.MetricKey.PREDICTION_MEAN] = metrics_lib.streaming_mean(
      predictions)
  metrics[metric_key.MetricKey.LABEL_MEAN] = metrics_lib.streaming_mean(
      labels_tensor)
  # Also include the streaming mean of the label as an accuracy baseline, as
  # a reminder to users.
  metrics[metric_key.MetricKey.ACCURACY_BASELINE] = metrics_lib.streaming_mean(
      labels_tensor)

  metrics[metric_key.MetricKey.AUC] = metrics_lib.streaming_auc(
      labels=labels_tensor, predictions=predictions)

  for threshold in thresholds:
    # Binarize the predictions: 1.0 where prediction >= threshold, else 0.0.
    predictions_at_threshold = math_ops.cast(
        math_ops.greater_equal(predictions, threshold),
        dtypes.float32,
        name='predictions_at_threshold_%f' % threshold)
    # The per-threshold metric keys are format strings filled in with the
    # threshold value.
    metrics[metric_key.MetricKey.ACCURACY_MEAN % threshold] = (
        metrics_lib.streaming_accuracy(labels=labels_tensor,
                                       predictions=predictions_at_threshold))
    # Precision for positive examples.
    metrics[metric_key.MetricKey.PRECISION_MEAN % threshold] = (
        metrics_lib.streaming_precision(labels=labels_tensor,
                                        predictions=predictions_at_threshold))
    # Recall for positive examples.
    metrics[metric_key.MetricKey.RECALL_MEAN % threshold] = (
        metrics_lib.streaming_recall(labels=labels_tensor,
                                     predictions=predictions_at_threshold))

  return metrics