• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for pandas_io."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.contrib.learn.python.learn.learn_io import pandas_io
24from tensorflow.python.framework import errors
25from tensorflow.python.platform import test
26from tensorflow.python.training import coordinator
27from tensorflow.python.training import queue_runner_impl
28
29# pylint: disable=g-import-not-at-top
30try:
31  import pandas as pd
32  HAS_PANDAS = True
33except ImportError:
34  HAS_PANDAS = False
35
36
class PandasIoTest(test.TestCase):
  """Tests for `pandas_io.pandas_input_fn`.

  Most tests build the ops returned by an input function, start the queue
  runners that feed them, and compare the batches produced by `session.run`
  with the rows of a small DataFrame. Every test is reported as skipped,
  rather than silently passing, when pandas cannot be imported.
  """

  def setUp(self):
    # Check before calling the base-class setUp: when setUp raises SkipTest,
    # unittest does not call tearDown, so nothing should be set up yet.
    if not HAS_PANDAS:
      self.skipTest('pandas is not installed')
    super(PandasIoTest, self).setUp()

  def makeTestDataFrame(self, num_rows=4):
    """Builds a feature DataFrame and a target Series that share one index.

    Args:
      num_rows: Number of rows to generate. Defaults to 4.

    Returns:
      A tuple `(x, y)`. `x` is a DataFrame with columns `a = [0, 1, ...]` and
      `b = [32, 33, ...]`; `y` is a Series `[-32, -31, ...]`. Both are indexed
      by `[100, 101, ...]`.
    """
    index = np.arange(100, 100 + num_rows)
    a = np.arange(num_rows)
    b = np.arange(32, 32 + num_rows)
    x = pd.DataFrame({'a': a, 'b': b}, index=index)
    y = pd.Series(np.arange(-32, -32 + num_rows), index=index)
    return x, y

  def _startQueueRunners(self, session):
    """Starts the queue runners for `session` under a new Coordinator.

    Returns:
      A tuple `(coord, threads)` to pass to `_stopQueueRunners`.
    """
    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(session, coord=coord)
    return coord, threads

  def _stopQueueRunners(self, coord, threads):
    """Requests that the queue-runner threads stop and waits for them."""
    coord.request_stop()
    coord.join(threads)

  def callInputFnOnce(self, input_fn, session):
    """Builds `input_fn`'s ops and returns the values of one `session.run`."""
    results = input_fn()
    coord, threads = self._startQueueRunners(session)
    try:
      return session.run(results)
    finally:
      # Stop the threads even when `session.run` raises, so they do not
      # outlive the test.
      self._stopQueueRunners(coord, threads)

  def testPandasInputFn_IndexMismatch(self):
    x, _ = self.makeTestDataFrame()
    # Default index [0..3] does not match x's index [100..103].
    y_noindex = pd.Series(np.arange(-32, -28))
    with self.assertRaises(ValueError):
      pandas_io.pandas_input_fn(
          x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)

  def testPandasInputFn_ProducesExpectedOutputs(self):
    with self.cached_session() as session:
      x, y = self.makeTestDataFrame()
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=2, shuffle=False, num_epochs=1)

      features, target = self.callInputFnOnce(input_fn, session)

      self.assertAllEqual(features['a'], [0, 1])
      self.assertAllEqual(features['b'], [32, 33])
      self.assertAllEqual(target, [-32, -31])

  def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
    with self.cached_session() as session:
      x, y = self.makeTestDataFrame(num_rows=2)
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=128, shuffle=False, num_epochs=2)

      results = input_fn()
      coord, threads = self._startQueueRunners(session)
      try:
        # batch_size exceeds 2 epochs * 2 rows, so both epochs arrive in a
        # single batch, after which the input is exhausted.
        features, target = session.run(results)
        self.assertAllEqual(features['a'], [0, 1, 0, 1])
        self.assertAllEqual(features['b'], [32, 33, 32, 33])
        self.assertAllEqual(target, [-32, -31, -32, -31])

        with self.assertRaises(errors.OutOfRangeError):
          session.run(results)
      finally:
        self._stopQueueRunners(coord, threads)

  def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self):
    with self.cached_session() as session:
      x, y = self.makeTestDataFrame(num_rows=5)
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=2, shuffle=False, num_epochs=1)

      results = input_fn()
      coord, threads = self._startQueueRunners(session)
      try:
        features, target = session.run(results)
        self.assertAllEqual(features['a'], [0, 1])
        self.assertAllEqual(features['b'], [32, 33])
        self.assertAllEqual(target, [-32, -31])

        features, target = session.run(results)
        self.assertAllEqual(features['a'], [2, 3])
        self.assertAllEqual(features['b'], [34, 35])
        self.assertAllEqual(target, [-30, -29])

        # The fifth row is emitted on its own as a final partial batch.
        features, target = session.run(results)
        self.assertAllEqual(features['a'], [4])
        self.assertAllEqual(features['b'], [36])
        self.assertAllEqual(target, [-28])

        with self.assertRaises(errors.OutOfRangeError):
          session.run(results)
      finally:
        self._stopQueueRunners(coord, threads)

  def testPandasInputFn_OnlyX(self):
    with self.cached_session() as session:
      x, _ = self.makeTestDataFrame()
      input_fn = pandas_io.pandas_input_fn(
          x, y=None, batch_size=2, shuffle=False, num_epochs=1)

      # With y=None the input function returns only the features dict.
      features = self.callInputFnOnce(input_fn, session)

      self.assertAllEqual(features['a'], [0, 1])
      self.assertAllEqual(features['b'], [32, 33])

  def testPandasInputFn_ExcludesIndex(self):
    with self.cached_session() as session:
      x, y = self.makeTestDataFrame()
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=2, shuffle=False, num_epochs=1)

      features, _ = self.callInputFnOnce(input_fn, session)

      self.assertNotIn('index', features)

  def assertInputsCallableNTimes(self, input_fn, session, n):
    """Asserts `input_fn` yields exactly `n` batches, then OutOfRangeError."""
    inputs = input_fn()
    coord, threads = self._startQueueRunners(session)
    try:
      for _ in range(n):
        session.run(inputs)
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)
    finally:
      self._stopQueueRunners(coord, threads)

  def testPandasInputFn_RespectsEpoch_NoShuffle(self):
    with self.cached_session() as session:
      x, y = self.makeTestDataFrame()
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=4, shuffle=False, num_epochs=1)

      self.assertInputsCallableNTimes(input_fn, session, 1)

  def testPandasInputFn_RespectsEpoch_WithShuffle(self):
    with self.cached_session() as session:
      x, y = self.makeTestDataFrame()
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=4, shuffle=True, num_epochs=1)

      self.assertInputsCallableNTimes(input_fn, session, 1)

  def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self):
    with self.cached_session() as session:
      x, y = self.makeTestDataFrame()
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2)

      # 4 rows * 2 epochs / batch_size 2 = 4 batches.
      self.assertInputsCallableNTimes(input_fn, session, 4)

  def testPandasInputFn_RespectsEpochUnevenBatches(self):
    x, y = self.makeTestDataFrame()
    with self.cached_session() as session:
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=3, shuffle=False, num_epochs=1)

      # 4 rows with batch_size=3 give one full batch of 3, then a partial
      # batch holding the single remaining row; the epoch ends after that.
      self.assertInputsCallableNTimes(input_fn, session, 2)

  def testPandasInputFn_Idempotent(self):
    # Building and calling fresh input functions over the same x and y
    # repeatedly must not raise. NOTE(review): presumably this guards against
    # pandas_input_fn mutating its inputs; confirm intent.
    x, y = self.makeTestDataFrame()
    for _ in range(2):
      pandas_io.pandas_input_fn(
          x, y, batch_size=2, shuffle=False, num_epochs=1)()
    for _ in range(2):
      pandas_io.pandas_input_fn(
          x, y, batch_size=2, shuffle=True, num_epochs=1)()
234
if __name__ == '__main__':
  # Script entry point: run this module's tests via TensorFlow's test main.
  test.main()
237