# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Custom optimizer tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random

import numpy as np

from tensorflow.python.training import training_util
from tensorflow.contrib.learn.python import learn
from tensorflow.contrib.learn.python.learn import datasets
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn.estimators import estimator as estimator_lib
from tensorflow.contrib.learn.python.learn.estimators._sklearn import accuracy_score
from tensorflow.contrib.learn.python.learn.estimators._sklearn import train_test_split
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.training import momentum as momentum_lib


class FeatureEngineeringFunctionTest(test.TestCase):
  """Tests feature_engineering_fn."""

  def testFeatureEngineeringFn(self):
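    """model_fn sees feature_engineering_fn's output, not input_fn's."""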

    def input_fn():
      return {
          "x": constant_op.constant([1.])
      }, {
          "y": constant_op.constant([11.])
      }

    def feature_engineering_fn(features, labels):
      _, _ = features, labels
      return {
          "transformed_x": constant_op.constant([9.])
      }, {
          "transformed_y": constant_op.constant([99.])
      }

    def model_fn(features, labels):
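      # contrib.learn's Estimator accepts a plain
      # (predictions, loss, train_op) tuple from model_fn.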
      # dummy variable:
      _ = variables_lib.Variable([0.])
      _ = labels
      predictions = features["transformed_x"]
      loss = constant_op.constant([2.])
      update_global_step = training_util.get_global_step().assign_add(1)
      return predictions, loss, update_global_step

    estimator = estimator_lib.Estimator(
        model_fn=model_fn, feature_engineering_fn=feature_engineering_fn)
    estimator.fit(input_fn=input_fn, steps=1)
    prediction = next(estimator.predict(input_fn=input_fn, as_iterable=True))
    # predictions = transformed_x (9)
    self.assertEqual(9., prediction)
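    # The MetricSpec below just echoes the labels, exposing what model_fn
    # received after feature engineering.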
    metrics = estimator.evaluate(
        input_fn=input_fn,
        steps=1,
        metrics={
            "label": metric_spec.MetricSpec(lambda predictions, labels: labels)
        })
    # labels = transformed_y (99)
    self.assertEqual(99., metrics["label"])

  def testFeatureEngineeringFnWithSameName(self):
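    """feature_engineering_fn may overwrite an existing key (GitHub #12205)."""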

    def input_fn():
      return {
          "x": constant_op.constant(["9."])
      }, {
          "y": constant_op.constant(["99."])
      }

    def feature_engineering_fn(features, labels):
      # GitHub #12205: string_split raises a TypeError if this function is
      # called twice.
      _ = string_ops.string_split(features["x"])
      features["x"] = constant_op.constant([9.])
      labels["y"] = constant_op.constant([99.])
      return features, labels

    def model_fn(features, labels):
      # dummy variable:
      _ = variables_lib.Variable([0.])
      _ = labels
      predictions = features["x"]
      loss = constant_op.constant([2.])
      update_global_step = training_util.get_global_step().assign_add(1)
      return predictions, loss, update_global_step

    estimator = estimator_lib.Estimator(
        model_fn=model_fn, feature_engineering_fn=feature_engineering_fn)
    estimator.fit(input_fn=input_fn, steps=1)
    prediction = next(estimator.predict(input_fn=input_fn, as_iterable=True))
    # predictions = x (transformed to 9)
    self.assertEqual(9., prediction)
    metrics = estimator.evaluate(
        input_fn=input_fn,
        steps=1,
        metrics={
            "label": metric_spec.MetricSpec(lambda predictions, labels: labels)
        })
    # labels = y (transformed to 99)
    self.assertEqual(99., metrics["label"])

  def testNoneFeatureEngineeringFn(self):
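    """Without a feature_engineering_fn, input_fn's output is used as-is."""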

    def input_fn():
      return {
          "x": constant_op.constant([1.])
      }, {
          "y": constant_op.constant([11.])
      }

    def feature_engineering_fn(features, labels):
      _, _ = features, labels
      return {
          "x": constant_op.constant([9.])
      }, {
          "y": constant_op.constant([99.])
      }

    def model_fn(features, labels):
      # dummy variable:
      _ = variables_lib.Variable([0.])
      _ = labels
      predictions = features["x"]
      loss = constant_op.constant([2.])
      update_global_step = training_util.get_global_step().assign_add(1)
      return predictions, loss, update_global_step

    estimator_with_fe_fn = estimator_lib.Estimator(
        model_fn=model_fn, feature_engineering_fn=feature_engineering_fn)
    estimator_with_fe_fn.fit(input_fn=input_fn, steps=1)
    estimator_without_fe_fn = estimator_lib.Estimator(model_fn=model_fn)
    estimator_without_fe_fn.fit(input_fn=input_fn, steps=1)

    # predictions = x: 9. with feature_engineering_fn, 1. without it
    prediction_with_fe_fn = next(
        estimator_with_fe_fn.predict(input_fn=input_fn, as_iterable=True))
    self.assertEqual(9., prediction_with_fe_fn)
    prediction_without_fe_fn = next(
        estimator_without_fe_fn.predict(input_fn=input_fn, as_iterable=True))
    self.assertEqual(1., prediction_without_fe_fn)


class CustomOptimizer(test.TestCase):
  """Custom optimizer tests."""

  def testIrisMomentum(self):
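    """DNNClassifier on Iris with a custom MomentumOptimizer."""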
    random.seed(42)

    iris = datasets.load_iris()
    x_train, x_test, y_train, y_test = train_test_split(
        iris.data, iris.target, test_size=0.2, random_state=42)

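    # DNNClassifier also accepts a zero-argument callable for `optimizer`;
    # the MomentumOptimizer is then built inside the estimator's own graph.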
    def custom_optimizer():
      return momentum_lib.MomentumOptimizer(learning_rate=0.01, momentum=0.9)

    classifier = learn.DNNClassifier(
        hidden_units=[10, 20, 10],
        feature_columns=learn.infer_real_valued_columns_from_input(x_train),
        n_classes=3,
        optimizer=custom_optimizer,
        config=learn.RunConfig(tf_random_seed=1))
    classifier.fit(x_train, y_train, steps=400)
    predictions = np.array(list(classifier.predict_classes(x_test)))
    score = accuracy_score(y_test, predictions)

    self.assertGreater(score, 0.65, "Failed with score = {0}".format(score))


if __name__ == "__main__":
  test.main()