# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for LogisticRegressor."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib import layers
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.contrib.learn.python.learn.estimators import logistic_regressor
from tensorflow.contrib.learn.python.learn.estimators import metric_key
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.training import training_util


def _iris_data_input_fn():
  # Converts iris data to a logistic regression problem.
  iris = base.load_iris()
  ids = np.where((iris.target == 0) | (iris.target == 1))
  features = constant_op.constant(iris.data[ids], dtype=dtypes.float32)
  labels = constant_op.constant(iris.target[ids], dtype=dtypes.float32)
  labels = array_ops.reshape(labels, labels.get_shape().concatenate(1))
  return features, labels


def _logistic_regression_model_fn(features, labels, mode):
  _ = mode
  logits = layers.linear(
      features,
      1,
      weights_initializer=init_ops.zeros_initializer(),
      # Intentionally uses really awful initial values so that
      # AUC/precision/recall/etc will change meaningfully even on a toy
      # dataset.
      biases_initializer=init_ops.constant_initializer(-10.0))
  predictions = math_ops.sigmoid(logits)
  loss = losses.sigmoid_cross_entropy(labels, logits)
  train_op = optimizers.optimize_loss(
      loss,
      training_util.get_global_step(),
      optimizer='Adagrad',
      learning_rate=0.1)
  return predictions, loss, train_op


class LogisticRegressorTest(test.TestCase):

  def test_fit_and_evaluate_metrics(self):
    """Tests basic fit and evaluate, and checks the evaluation metrics."""
    regressor = logistic_regressor.LogisticRegressor(
        model_fn=_logistic_regression_model_fn)

    # Get some (intentionally horrible) baseline metrics.
    regressor.fit(input_fn=_iris_data_input_fn, steps=1)
    eval_metrics = regressor.evaluate(input_fn=_iris_data_input_fn, steps=1)
    self.assertNear(
        0.0, eval_metrics[metric_key.MetricKey.PREDICTION_MEAN], err=1e-3)
    self.assertNear(
        0.5, eval_metrics[metric_key.MetricKey.LABEL_MEAN], err=1e-6)
    self.assertNear(
        0.5, eval_metrics[metric_key.MetricKey.ACCURACY_BASELINE], err=1e-6)
    self.assertNear(0.5, eval_metrics[metric_key.MetricKey.AUC], err=1e-6)
    self.assertNear(
        0.5, eval_metrics[metric_key.MetricKey.ACCURACY_MEAN % 0.5], err=1e-6)
    self.assertNear(
        0.0, eval_metrics[metric_key.MetricKey.PRECISION_MEAN % 0.5], err=1e-6)
    self.assertNear(
        0.0, eval_metrics[metric_key.MetricKey.RECALL_MEAN % 0.5], err=1e-6)

    # Train for more steps and check the metrics again.
    regressor.fit(input_fn=_iris_data_input_fn, steps=100)
    eval_metrics = regressor.evaluate(input_fn=_iris_data_input_fn, steps=1)
    # Mean prediction moves from ~0.0 to ~0.5 as we stop predicting all 0's.
    self.assertNear(
        0.5, eval_metrics[metric_key.MetricKey.PREDICTION_MEAN], err=1e-2)
    # Label mean and baseline both remain the same at 0.5.
    self.assertNear(
        0.5, eval_metrics[metric_key.MetricKey.LABEL_MEAN], err=1e-6)
    self.assertNear(
        0.5, eval_metrics[metric_key.MetricKey.ACCURACY_BASELINE], err=1e-6)
    # AUC improves from 0.5 to 1.0.
    self.assertNear(1.0, eval_metrics[metric_key.MetricKey.AUC], err=1e-6)
    # Accuracy improves from 0.5 to >0.9.
    self.assertTrue(
        eval_metrics[metric_key.MetricKey.ACCURACY_MEAN % 0.5] > 0.9)
    # Precision improves from 0.0 to 1.0.
    self.assertNear(
        1.0, eval_metrics[metric_key.MetricKey.PRECISION_MEAN % 0.5], err=1e-6)
    # Recall improves from 0.0 to >0.9.
    self.assertTrue(eval_metrics[metric_key.MetricKey.RECALL_MEAN % 0.5] > 0.9)


if __name__ == '__main__':
  test.main()