• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for `DataFeeder`."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os.path
22import numpy as np
23import six
24from six.moves import xrange  # pylint: disable=redefined-builtin
25
26# pylint: disable=wildcard-import
27from tensorflow.contrib.learn.python.learn.learn_io import *
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.lib.io import file_io
31from tensorflow.python.platform import test
32
33# pylint: enable=wildcard-import
34
35
class DataFeederTest(test.TestCase):
  # pylint: disable=undefined-variable
  """Tests for `DataFeeder`.

  Most tests build a feeder twice: once from plain arrays and once from the
  same data wrapped in a two-key dict (see `_wrap_dict`), and then check the
  feed dict produced by `get_feed_dict_fn()`.

  Names such as `data_feeder`, `HAS_PANDAS`, `HAS_DASK`, `pd` and `dd` are not
  imported explicitly in this file; presumably they come from the wildcard
  import of `learn_io` (hence the pylint suppression above).
  """

  def setUp(self):
    """Creates a scratch directory under the test's temp dir."""
    self._base_dir = os.path.join(self.get_temp_dir(), 'base_dir')
    file_io.create_dir(self._base_dir)

  def tearDown(self):
    """Removes the scratch directory created in `setUp`."""
    file_io.delete_recursively(self._base_dir)

  def _wrap_dict(self, data, prepend=''):
    """Returns a dict mapping `prepend + '1'` and `prepend + '2'` to `data`.

    Both keys reference the same `data` object; no copy is made.
    """
    return {prepend + '1': data, prepend + '2': data}

  def _assert_raises(self, input_data):
    """Asserts that building a `DataFeeder` from `input_data` raises TypeError.

    The error message must match the pattern 'annot convert'.
    """
    # `assertRaisesRegexp` is a deprecated alias that was removed in
    # Python 3.12; `six.assertRaisesRegex` picks the right name on 2 and 3.
    with six.assertRaisesRegex(self, TypeError, 'annot convert'):
      data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1)

  def _assert_dtype(self, expected_np_dtype, expected_tf_dtype, input_data):
    """Checks the dtypes a `DataFeeder` reports and builds for `input_data`.

    Args:
      expected_np_dtype: dtype expected in `feeder.input_dtype` (for every
        value when `input_data` is a dict).
      expected_tf_dtype: dtype expected on the input returned by
        `feeder.input_builder()` (for every value when that is a dict).
      input_data: array, or dict of arrays, handed to `DataFeeder` as `x`.
    """
    feeder = data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1)
    if isinstance(input_data, dict):
      for np_dtype in feeder.input_dtype.values():
        self.assertEqual(expected_np_dtype, np_dtype)
    else:
      self.assertEqual(expected_np_dtype, feeder.input_dtype)
    # Build the inputs in a fresh graph so tests do not share graph state.
    with ops.Graph().as_default() as g, self.session(g):
      inp, _ = feeder.input_builder()
      if isinstance(inp, dict):
        for tensor in inp.values():
          self.assertEqual(expected_tf_dtype, tensor.dtype)
      else:
        self.assertEqual(expected_tf_dtype, inp.dtype)

  def test_input_int8(self):
    """int8 input keeps dtype int8 (plain and dict-wrapped)."""
    data = np.matrix([[1, 2], [3, 4]], dtype=np.int8)
    self._assert_dtype(np.int8, dtypes.int8, data)
    self._assert_dtype(np.int8, dtypes.int8, self._wrap_dict(data))

  def test_input_int16(self):
    """int16 input keeps dtype int16 (plain and dict-wrapped)."""
    data = np.matrix([[1, 2], [3, 4]], dtype=np.int16)
    self._assert_dtype(np.int16, dtypes.int16, data)
    self._assert_dtype(np.int16, dtypes.int16, self._wrap_dict(data))

  def test_input_int32(self):
    """int32 input keeps dtype int32 (plain and dict-wrapped)."""
    data = np.matrix([[1, 2], [3, 4]], dtype=np.int32)
    self._assert_dtype(np.int32, dtypes.int32, data)
    self._assert_dtype(np.int32, dtypes.int32, self._wrap_dict(data))

  def test_input_int64(self):
    """int64 input keeps dtype int64 (plain and dict-wrapped)."""
    data = np.matrix([[1, 2], [3, 4]], dtype=np.int64)
    self._assert_dtype(np.int64, dtypes.int64, data)
    self._assert_dtype(np.int64, dtypes.int64, self._wrap_dict(data))

  def test_input_uint32(self):
    """uint32 input keeps dtype uint32 (plain and dict-wrapped)."""
    data = np.matrix([[1, 2], [3, 4]], dtype=np.uint32)
    self._assert_dtype(np.uint32, dtypes.uint32, data)
    self._assert_dtype(np.uint32, dtypes.uint32, self._wrap_dict(data))

  def test_input_uint64(self):
    """uint64 input keeps dtype uint64 (plain and dict-wrapped)."""
    data = np.matrix([[1, 2], [3, 4]], dtype=np.uint64)
    self._assert_dtype(np.uint64, dtypes.uint64, data)
    self._assert_dtype(np.uint64, dtypes.uint64, self._wrap_dict(data))

  def test_input_uint8(self):
    """uint8 input keeps dtype uint8 (plain and dict-wrapped)."""
    data = np.matrix([[1, 2], [3, 4]], dtype=np.uint8)
    self._assert_dtype(np.uint8, dtypes.uint8, data)
    self._assert_dtype(np.uint8, dtypes.uint8, self._wrap_dict(data))

  def test_input_uint16(self):
    """uint16 input keeps dtype uint16 (plain and dict-wrapped)."""
    data = np.matrix([[1, 2], [3, 4]], dtype=np.uint16)
    self._assert_dtype(np.uint16, dtypes.uint16, data)
    self._assert_dtype(np.uint16, dtypes.uint16, self._wrap_dict(data))

  def test_input_float16(self):
    """float16 input keeps dtype float16 (plain and dict-wrapped)."""
    data = np.matrix([[1, 2], [3, 4]], dtype=np.float16)
    self._assert_dtype(np.float16, dtypes.float16, data)
    self._assert_dtype(np.float16, dtypes.float16, self._wrap_dict(data))

  def test_input_float32(self):
    """float32 input keeps dtype float32 (plain and dict-wrapped)."""
    data = np.matrix([[1, 2], [3, 4]], dtype=np.float32)
    self._assert_dtype(np.float32, dtypes.float32, data)
    self._assert_dtype(np.float32, dtypes.float32, self._wrap_dict(data))

  def test_input_float64(self):
    """float64 input keeps dtype float64 (plain and dict-wrapped)."""
    data = np.matrix([[1, 2], [3, 4]], dtype=np.float64)
    self._assert_dtype(np.float64, dtypes.float64, data)
    self._assert_dtype(np.float64, dtypes.float64, self._wrap_dict(data))

  def test_input_bool(self):
    """Boolean input keeps a boolean dtype (plain and dict-wrapped)."""
    data = np.array([[False for _ in xrange(2)] for _ in xrange(2)])
    # `np.bool` was a deprecated alias of the builtin `bool` and has been
    # removed from NumPy; `np.bool_` is the NumPy boolean scalar type.
    self._assert_dtype(np.bool_, dtypes.bool, data)
    self._assert_dtype(np.bool_, dtypes.bool, self._wrap_dict(data))

  def test_input_string(self):
    """String input reports the array's own dtype and builds `dtypes.string`."""
    input_data = np.array([['str%d' % i for i in xrange(2)] for _ in xrange(2)])
    self._assert_dtype(input_data.dtype, dtypes.string, input_data)
    self._assert_dtype(input_data.dtype, dtypes.string,
                       self._wrap_dict(input_data))

  def _assertAllClose(self, src, dest, src_key_of=None, src_prop=None):
    """Asserts `dest` is close to `src` after an optional transformation.

    Each source item is first replaced by its `src_prop` attribute (when
    `src_prop` is truthy) and then looked up in `src_key_of` (when that is not
    None). The tests use this to go from an input object to its `name` and
    from there to the matching entry of a feed dict.

    Args:
      src: a single item, or a dict whose values are each checked against
        the same `dest`.
      dest: expected value.
      src_key_of: optional mapping used to look up the transformed item.
      src_prop: optional attribute name read from each item.
    """

    def resolve(item):
      val = getattr(item, src_prop) if src_prop else item
      return val if src_key_of is None else src_key_of[val]

    if isinstance(src, dict):
      for item in src.values():
        self.assertAllClose(resolve(item), dest)
    else:
      self.assertAllClose(resolve(src), dest)

  def test_unsupervised(self):
    """With no labels, the first fed batch of size 1 is `[[1, 2]]`."""

    def func(feeder):
      with self.cached_session():
        inp, _ = feeder.input_builder()
        feed_dict_fn = feeder.get_feed_dict_fn()
        feed_dict = feed_dict_fn()
        self._assertAllClose(inp, [[1, 2]], feed_dict, 'name')

    data = np.matrix([[1, 2], [2, 3], [3, 4]])
    func(data_feeder.DataFeeder(data, None, n_classes=0, batch_size=1))
    func(
        data_feeder.DataFeeder(
            self._wrap_dict(data), None, n_classes=0, batch_size=1))

  def test_data_feeder_regression(self):
    """Checks the fed inputs and targets for a 2-sample regression set."""

    def func(df):
      inp, out = df.input_builder()
      feed_dict_fn = df.get_feed_dict_fn()
      feed_dict = feed_dict_fn()
      self._assertAllClose(inp, [[3, 4], [1, 2]], feed_dict, 'name')
      self._assertAllClose(out, [2, 1], feed_dict, 'name')

    x = np.matrix([[1, 2], [3, 4]])
    y = np.array([1, 2])
    func(data_feeder.DataFeeder(x, y, n_classes=0, batch_size=3))
    func(
        data_feeder.DataFeeder(
            self._wrap_dict(x, 'in'),
            self._wrap_dict(y, 'out'),
            n_classes=self._wrap_dict(0, 'out'),
            batch_size=3))

  def test_epoch(self):
    """The epoch value fed stays 0 for three batches, then becomes 1."""

    def func(feeder):
      with self.cached_session():
        feeder.input_builder()
        epoch = feeder.make_epoch_variable()
        feed_dict_fn = feeder.get_feed_dict_fn()
        # First input
        feed_dict = feed_dict_fn()
        self.assertAllClose(feed_dict[epoch.name], [0])
        # Second input
        feed_dict = feed_dict_fn()
        self.assertAllClose(feed_dict[epoch.name], [0])
        # Third input
        feed_dict = feed_dict_fn()
        self.assertAllClose(feed_dict[epoch.name], [0])
        # Back to the first input again, so new epoch.
        feed_dict = feed_dict_fn()
        self.assertAllClose(feed_dict[epoch.name], [1])

    data = np.matrix([[1, 2], [2, 3], [3, 4]])
    labels = np.array([0, 0, 1])
    func(data_feeder.DataFeeder(data, labels, n_classes=0, batch_size=1))
    func(
        data_feeder.DataFeeder(
            self._wrap_dict(data, 'in'),
            self._wrap_dict(labels, 'out'),
            n_classes=self._wrap_dict(0, 'out'),
            batch_size=1))

  def test_data_feeder_multioutput_regression(self):
    """Checks the fed inputs and 2-column targets for regression."""

    def func(df):
      inp, out = df.input_builder()
      feed_dict_fn = df.get_feed_dict_fn()
      feed_dict = feed_dict_fn()
      self._assertAllClose(inp, [[3, 4], [1, 2]], feed_dict, 'name')
      self._assertAllClose(out, [[3, 4], [1, 2]], feed_dict, 'name')

    x = np.matrix([[1, 2], [3, 4]])
    y = np.array([[1, 2], [3, 4]])
    func(data_feeder.DataFeeder(x, y, n_classes=0, batch_size=2))
    func(
        data_feeder.DataFeeder(
            self._wrap_dict(x, 'in'),
            self._wrap_dict(y, 'out'),
            n_classes=self._wrap_dict(0, 'out'),
            batch_size=2))

  def test_data_feeder_multioutput_classification(self):
    """With `n_classes=5`, multi-column labels are fed as one-hot rows."""

    def func(df):
      inp, out = df.input_builder()
      feed_dict_fn = df.get_feed_dict_fn()
      feed_dict = feed_dict_fn()
      self._assertAllClose(inp, [[3, 4], [1, 2]], feed_dict, 'name')
      self._assertAllClose(
          out, [[[0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]],
                [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]], feed_dict,
          'name')

    x = np.matrix([[1, 2], [3, 4]])
    y = np.array([[0, 1, 2], [2, 3, 4]])
    func(data_feeder.DataFeeder(x, y, n_classes=5, batch_size=2))
    func(
        data_feeder.DataFeeder(
            self._wrap_dict(x, 'in'),
            self._wrap_dict(y, 'out'),
            n_classes=self._wrap_dict(5, 'out'),
            batch_size=2))

  def test_streaming_data_feeder(self):
    """Checks `StreamingDataFeeder` output for full and non-full batches.

    The iterators yield only two samples, so `batch_size=10` exercises the
    case where the stream ends before a batch is filled.
    """

    def func(df):
      inp, out = df.input_builder()
      feed_dict_fn = df.get_feed_dict_fn()
      feed_dict = feed_dict_fn()
      self._assertAllClose(inp, [[[1, 2]], [[3, 4]]], feed_dict, 'name')
      self._assertAllClose(out, [[[1], [2]], [[2], [2]]], feed_dict, 'name')

    def x_iter(wrap_dict=False):
      # Yields two feature samples, optionally wrapped as {'in1', 'in2'} dicts.
      yield np.array([[1, 2]]) if not wrap_dict else self._wrap_dict(
          np.array([[1, 2]]), 'in')
      yield np.array([[3, 4]]) if not wrap_dict else self._wrap_dict(
          np.array([[3, 4]]), 'in')

    def y_iter(wrap_dict=False):
      # Yields two label samples, optionally wrapped as {'out1', 'out2'} dicts.
      yield np.array([[1], [2]]) if not wrap_dict else self._wrap_dict(
          np.array([[1], [2]]), 'out')
      yield np.array([[2], [2]]) if not wrap_dict else self._wrap_dict(
          np.array([[2], [2]]), 'out')

    func(
        data_feeder.StreamingDataFeeder(
            x_iter(), y_iter(), n_classes=0, batch_size=2))
    func(
        data_feeder.StreamingDataFeeder(
            x_iter(True),
            y_iter(True),
            n_classes=self._wrap_dict(0, 'out'),
            batch_size=2))
    # Test non-full batches.
    func(
        data_feeder.StreamingDataFeeder(
            x_iter(), y_iter(), n_classes=0, batch_size=10))
    func(
        data_feeder.StreamingDataFeeder(
            x_iter(True),
            y_iter(True),
            n_classes=self._wrap_dict(0, 'out'),
            batch_size=10))

  def test_dask_data_feeder(self):
    """Checks `DaskDataFeeder` output; does nothing without pandas and dask."""
    if not (HAS_PANDAS and HAS_DASK):
      # Optional dependencies are missing, so there is nothing to exercise.
      return
    x = pd.DataFrame(
        dict(
            a=np.array([.1, .3, .4, .6, .2, .1, .6]),
            b=np.array([.7, .8, .1, .2, .5, .3, .9])))
    x = dd.from_pandas(x, npartitions=2)
    y = pd.DataFrame(dict(labels=np.array([1, 0, 2, 1, 0, 1, 2])))
    y = dd.from_pandas(y, npartitions=2)
    # TODO(ipolosukhin): Remove or restore this.
    # x = extract_dask_data(x)
    # y = extract_dask_labels(y)
    df = data_feeder.DaskDataFeeder(x, y, n_classes=2, batch_size=2)
    inp, out = df.input_builder()
    feed_dict_fn = df.get_feed_dict_fn()
    feed_dict = feed_dict_fn()
    self.assertAllClose(feed_dict[inp.name], [[0.40000001, 0.1],
                                              [0.60000002, 0.2]])
    self.assertAllClose(feed_dict[out.name], [[0., 0., 1.], [0., 1., 0.]])

  # TODO(rohanj): Fix this test by fixing data_feeder. Currently, h5py doesn't
  # support permutation based indexing lookups (More documentation at
  # http://docs.h5py.org/en/latest/high/dataset.html#fancy-indexing)
  def DISABLED_test_hdf5_data_feeder(self):
    """Checks `DataFeeder` on h5py datasets (disabled; see the TODO above).

    The name does not start with `test`, so the test runner does not collect
    this method. It prints a message and returns if h5py is not installed.
    """

    def func(df):
      inp, out = df.input_builder()
      feed_dict_fn = df.get_feed_dict_fn()
      feed_dict = feed_dict_fn()
      self._assertAllClose(inp, [[3, 4], [1, 2]], feed_dict, 'name')
      # Use the feed-dict aware helper, as test_data_feeder_regression does;
      # `self.assertAllClose` would treat `feed_dict` and 'name' as tolerances.
      self._assertAllClose(out, [2, 1], feed_dict, 'name')

    # Only the import is guarded, so an ImportError raised by the code under
    # test is not mistaken for a missing h5py installation.
    try:
      import h5py  # pylint: disable=g-import-not-at-top
    except ImportError:
      print("Skipped test for hdf5 since it's not installed.")
      return

    x = np.matrix([[1, 2], [3, 4]])
    y = np.array([1, 2])
    file_path = os.path.join(self._base_dir, 'test_hdf5.h5')
    with h5py.File(file_path, 'w') as h5f:
      h5f.create_dataset('x', data=x)
      h5f.create_dataset('y', data=y)
    # Keep the file open while the datasets are in use, and make sure the
    # handle is closed afterwards even if an assertion fails.
    with h5py.File(file_path, 'r') as h5f:
      x = h5f['x']
      y = h5f['y']
      func(data_feeder.DataFeeder(x, y, n_classes=0, batch_size=3))
      func(
          data_feeder.DataFeeder(
              self._wrap_dict(x, 'in'),
              self._wrap_dict(y, 'out'),
              n_classes=self._wrap_dict(0, 'out'),
              batch_size=3))
346
347
class SetupPredictDataFeederTest(DataFeederTest):
  """Tests for `DataFeeder.setup_predict_data_feeder`.

  Subclasses `DataFeederTest` to reuse its `_wrap_dict` and `_assertAllClose`
  helpers.

  NOTE(review): the subclass also inherits every `test_*` method of
  `DataFeederTest`, so those presumably run a second time under this class;
  confirm whether that is intended.
  """

  def test_iterable_data(self):
    """Three rows batched by two come out as a batch of 2, then a batch of 1."""
    # pylint: disable=undefined-variable

    def func(df):
      self._assertAllClose(six.next(df), [[1, 2], [3, 4]])
      self._assertAllClose(six.next(df), [[5, 6]])

    data = [[1, 2], [3, 4], [5, 6]]
    x = iter(data)
    # Same rows, each wrapped as a two-key dict.
    x_dict = iter([self._wrap_dict(v) for v in data])
    func(data_feeder.setup_predict_data_feeder(x, batch_size=2))
    func(data_feeder.setup_predict_data_feeder(x_dict, batch_size=2))
363
364
if __name__ == '__main__':
  # Run the tests through TensorFlow's test runner when executed as a script.
  test.main()