• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""tf.learn IO operation tests."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import random
22
23# pylint: disable=wildcard-import
24from tensorflow.contrib.learn.python import learn
25from tensorflow.contrib.learn.python.learn import datasets
26from tensorflow.contrib.learn.python.learn.estimators._sklearn import accuracy_score
27from tensorflow.contrib.learn.python.learn.learn_io import *
28from tensorflow.python.platform import test
29
30# pylint: enable=wildcard-import
31
32
class IOTest(test.TestCase):
  # pylint: disable=undefined-variable
  """tf.learn IO operation tests.

  `HAS_PANDAS`, `HAS_DASK`, `extract_dask_data` and `extract_dask_labels`
  come from the wildcard import of `learn_io` at the top of this file.

  Tests that need an optional dependency (pandas and/or dask) are reported as
  *skipped* when that dependency is unavailable, instead of silently passing
  without having checked anything.
  """

  def _skip_without_pandas(self):
    """Skips the calling test when pandas is not available."""
    if not HAS_PANDAS:
      self.skipTest("No pandas installed. pandas-related tests are skipped.")

  def _skip_without_dask(self):
    """Skips the calling test when dask is not available."""
    if not HAS_DASK:
      self.skipTest("No dask installed. dask-related tests are skipped.")

  def test_pandas_dataframe(self):
    """Fits a LinearClassifier on iris with DataFrame features and labels."""
    self._skip_without_pandas()
    import pandas as pd  # pylint: disable=g-import-not-at-top
    random.seed(42)
    iris = datasets.load_iris()
    data = pd.DataFrame(iris.data)
    labels = pd.DataFrame(iris.target)
    classifier = learn.LinearClassifier(
        feature_columns=learn.infer_real_valued_columns_from_input(data),
        n_classes=3)
    classifier.fit(data, labels, steps=100)
    # `labels` is a single-column DataFrame; column 0 holds the targets.
    score = accuracy_score(labels[0], list(classifier.predict_classes(data)))
    self.assertGreater(score, 0.5, "Failed with score = {0}".format(score))

  def test_pandas_series(self):
    """Fits a LinearClassifier on iris with labels given as a pandas Series."""
    self._skip_without_pandas()
    import pandas as pd  # pylint: disable=g-import-not-at-top
    random.seed(42)
    iris = datasets.load_iris()
    data = pd.DataFrame(iris.data)
    labels = pd.Series(iris.target)
    classifier = learn.LinearClassifier(
        feature_columns=learn.infer_real_valued_columns_from_input(data),
        n_classes=3)
    classifier.fit(data, labels, steps=100)
    score = accuracy_score(labels, list(classifier.predict_classes(data)))
    self.assertGreater(score, 0.5, "Failed with score = {0}".format(score))

  def test_string_data_formats(self):
    """Checks that string-valued DataFrames are rejected with ValueError."""
    self._skip_without_pandas()
    import pandas as pd  # pylint: disable=g-import-not-at-top
    with self.assertRaises(ValueError):
      learn.io.extract_pandas_data(pd.DataFrame({"Test": ["A", "B"]}))
    with self.assertRaises(ValueError):
      learn.io.extract_pandas_labels(pd.DataFrame({"Test": ["A", "B"]}))

  def test_dask_io(self):
    """Checks divisions/columns produced by the dask extraction helpers."""
    self._skip_without_dask()
    self._skip_without_pandas()
    import pandas as pd  # pylint: disable=g-import-not-at-top
    import dask.dataframe as dd  # pylint: disable=g-import-not-at-top
    # test dask.dataframe
    df = pd.DataFrame(
        dict(
            a=list("aabbcc"), b=list(range(6))),
        index=pd.date_range(
            start="20100101", periods=6))
    ddf = dd.from_pandas(df, npartitions=3)
    extracted_ddf = extract_dask_data(ddf)
    self.assertEqual(
        extracted_ddf.divisions, (0, 2, 4, 6),
        "Failed with divisions = {0}".format(extracted_ddf.divisions))
    self.assertEqual(
        extracted_ddf.columns.tolist(), ["a", "b"],
        "Failed with columns = {0}".format(extracted_ddf.columns))
    # test dask.series
    labels = ddf["a"]
    extracted_labels = extract_dask_labels(labels)
    self.assertEqual(
        extracted_labels.divisions, (0, 2, 4, 6),
        "Failed with divisions = {0}".format(extracted_labels.divisions))
    # labels should only have one column
    with self.assertRaises(ValueError):
      extract_dask_labels(ddf)

  def test_dask_iris_classification(self):
    """Fits a LinearClassifier on iris supplied as dask DataFrames."""
    self._skip_without_dask()
    self._skip_without_pandas()
    import pandas as pd  # pylint: disable=g-import-not-at-top
    import dask.dataframe as dd  # pylint: disable=g-import-not-at-top
    random.seed(42)
    iris = datasets.load_iris()
    data = pd.DataFrame(iris.data)
    data = dd.from_pandas(data, npartitions=2)
    labels = pd.DataFrame(iris.target)
    labels = dd.from_pandas(labels, npartitions=2)
    classifier = learn.LinearClassifier(
        feature_columns=learn.infer_real_valued_columns_from_input(data),
        n_classes=3)
    classifier.fit(data, labels, steps=100)
    # NOTE(review): the pandas tests score `predict_classes`; this one maps
    # `predict` over the partitions -- confirm both yield class ids.
    predictions = data.map_partitions(classifier.predict).compute()
    score = accuracy_score(labels.compute(), predictions)
    self.assertGreater(score, 0.5, "Failed with score = {0}".format(score))
122
123
# When executed as a script, hand control to the `test` module imported from
# tensorflow.python.platform, which runs the test cases defined above.
if __name__ == "__main__":
  test.main()
126