• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for LogisticRegressor."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.contrib import layers
24from tensorflow.python.training import training_util
25from tensorflow.contrib.layers.python.layers import optimizers
26from tensorflow.contrib.learn.python.learn.datasets import base
27from tensorflow.contrib.learn.python.learn.estimators import logistic_regressor
28from tensorflow.contrib.learn.python.learn.estimators import metric_key
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import init_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops.losses import losses
35from tensorflow.python.platform import test
36
37
def _iris_data_input_fn():
  """Input fn that turns the iris dataset into a binary classification task.

  Only the examples whose target is 0 or 1 are kept, so the data becomes a
  two-class (logistic regression) problem.

  Returns:
    A `(features, labels)` tuple of float32 constant tensors. `labels` has a
    trailing dimension of size 1 appended to its static shape.
  """
  dataset = base.load_iris()
  # Index tuple selecting only the rows labelled class 0 or class 1.
  keep = np.where(np.logical_or(dataset.target == 0, dataset.target == 1))
  feature_tensor = constant_op.constant(
      dataset.data[keep], dtype=dtypes.float32)
  flat_labels = constant_op.constant(
      dataset.target[keep], dtype=dtypes.float32)
  # Append a unit dimension; presumably so labels match the single-output
  # logits of the model fn (`layers.linear(features, 1, ...)`).
  label_tensor = array_ops.reshape(
      flat_labels, flat_labels.get_shape().concatenate(1))
  return feature_tensor, label_tensor
46
47
def _logistic_regression_model_fn(features, labels, mode):
  """Model fn for a single-output linear logistic regression.

  Args:
    features: Input feature tensor fed to a one-unit linear layer.
    labels: Target tensor scored against the logits with sigmoid
      cross-entropy.
    mode: Unused.

  Returns:
    A `(predictions, loss, train_op)` tuple: `predictions` is the sigmoid of
    the logits, and `train_op` minimizes `loss` with Adagrad at a learning
    rate of 0.1.
  """
  del mode  # Unused.
  # Weights start at zero and the bias at a deliberately awful -10.0, so that
  # AUC/precision/recall/etc. change meaningfully even on a toy dataset.
  logit_tensor = layers.linear(
      features,
      1,
      weights_initializer=init_ops.zeros_initializer(),
      biases_initializer=init_ops.constant_initializer(-10.0))
  probabilities = math_ops.sigmoid(logit_tensor)
  loss = losses.sigmoid_cross_entropy(labels, logit_tensor)
  train_op = optimizers.optimize_loss(
      loss,
      training_util.get_global_step(),
      learning_rate=0.1,
      optimizer='Adagrad')
  return probabilities, loss, train_op
65
66
class LogisticRegressorTest(test.TestCase):

  def test_fit_and_evaluate_metrics(self):
    """Tests basic fit and evaluate, and checks the evaluation metrics.

    Metrics are checked twice. The first check comes after a single training
    step, when the model's -10.0 initial bias makes it predict ~0 for every
    example. The second comes after 100 more steps, when AUC and precision are
    expected to reach 1.0 and accuracy and recall to exceed 0.9.
    """
    regressor = logistic_regressor.LogisticRegressor(
        model_fn=_logistic_regression_model_fn)
    # Metric keys formatted with 0.5 (presumably the classification
    # threshold); computed once instead of at every assertion.
    accuracy_key = metric_key.MetricKey.ACCURACY_MEAN % 0.5
    precision_key = metric_key.MetricKey.PRECISION_MEAN % 0.5
    recall_key = metric_key.MetricKey.RECALL_MEAN % 0.5

    # Get some (intentionally horrible) baseline metrics.
    regressor.fit(input_fn=_iris_data_input_fn, steps=1)
    eval_metrics = regressor.evaluate(input_fn=_iris_data_input_fn, steps=1)
    self.assertNear(
        0.0, eval_metrics[metric_key.MetricKey.PREDICTION_MEAN], err=1e-3)
    self.assertNear(
        0.5, eval_metrics[metric_key.MetricKey.LABEL_MEAN], err=1e-6)
    self.assertNear(
        0.5, eval_metrics[metric_key.MetricKey.ACCURACY_BASELINE], err=1e-6)
    self.assertNear(0.5, eval_metrics[metric_key.MetricKey.AUC], err=1e-6)
    self.assertNear(0.5, eval_metrics[accuracy_key], err=1e-6)
    self.assertNear(0.0, eval_metrics[precision_key], err=1e-6)
    self.assertNear(0.0, eval_metrics[recall_key], err=1e-6)

    # Train for more steps and check the metrics again.
    regressor.fit(input_fn=_iris_data_input_fn, steps=100)
    eval_metrics = regressor.evaluate(input_fn=_iris_data_input_fn, steps=1)
    # Mean prediction moves from ~0.0 to ~0.5 as we stop predicting all 0's.
    self.assertNear(
        0.5, eval_metrics[metric_key.MetricKey.PREDICTION_MEAN], err=1e-2)
    # Label mean and baseline both remain the same at 0.5.
    self.assertNear(
        0.5, eval_metrics[metric_key.MetricKey.LABEL_MEAN], err=1e-6)
    self.assertNear(
        0.5, eval_metrics[metric_key.MetricKey.ACCURACY_BASELINE], err=1e-6)
    # AUC improves from 0.5 to 1.0.
    self.assertNear(1.0, eval_metrics[metric_key.MetricKey.AUC], err=1e-6)
    # Accuracy improves from 0.5 to >0.9. assertGreater (unlike
    # assertTrue(x > y)) reports the actual value when it fails.
    self.assertGreater(eval_metrics[accuracy_key], 0.9)
    # Precision improves from 0.0 to 1.0.
    self.assertNear(1.0, eval_metrics[precision_key], err=1e-6)
    # Recall improves from 0.0 to >0.9.
    self.assertGreater(eval_metrics[recall_key], 0.9)
112
113
if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's test runner.
  test.main()
116