1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for `DataFeeder`.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os.path 22import numpy as np 23import six 24from six.moves import xrange # pylint: disable=redefined-builtin 25 26# pylint: disable=wildcard-import 27from tensorflow.contrib.learn.python.learn.learn_io import * 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.lib.io import file_io 31from tensorflow.python.platform import test 32 33# pylint: enable=wildcard-import 34 35 36class DataFeederTest(test.TestCase): 37 # pylint: disable=undefined-variable 38 """Tests for `DataFeeder`.""" 39 40 def setUp(self): 41 self._base_dir = os.path.join(self.get_temp_dir(), 'base_dir') 42 file_io.create_dir(self._base_dir) 43 44 def tearDown(self): 45 file_io.delete_recursively(self._base_dir) 46 47 def _wrap_dict(self, data, prepend=''): 48 return {prepend + '1': data, prepend + '2': data} 49 50 def _assert_raises(self, input_data): 51 with self.assertRaisesRegexp(TypeError, 'annot convert'): 52 data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1) 53 54 def _assert_dtype(self, expected_np_dtype, expected_tf_dtype, input_data): 55 feeder = data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1) 56 if isinstance(input_data, dict): 57 for v in list(feeder.input_dtype.values()): 58 self.assertEqual(expected_np_dtype, v) 59 else: 60 self.assertEqual(expected_np_dtype, feeder.input_dtype) 61 with ops.Graph().as_default() as g, self.session(g): 62 inp, _ = feeder.input_builder() 63 if isinstance(inp, dict): 64 for v in list(inp.values()): 65 self.assertEqual(expected_tf_dtype, v.dtype) 66 else: 67 self.assertEqual(expected_tf_dtype, inp.dtype) 68 69 def test_input_int8(self): 70 data = np.matrix([[1, 2], [3, 4]], dtype=np.int8) 71 self._assert_dtype(np.int8, dtypes.int8, data) 72 self._assert_dtype(np.int8, dtypes.int8, self._wrap_dict(data)) 73 74 def test_input_int16(self): 75 data = np.matrix([[1, 2], [3, 4]], dtype=np.int16) 76 self._assert_dtype(np.int16, dtypes.int16, data) 77 self._assert_dtype(np.int16, dtypes.int16, self._wrap_dict(data)) 78 79 def test_input_int32(self): 80 data = np.matrix([[1, 2], [3, 4]], dtype=np.int32) 81 self._assert_dtype(np.int32, dtypes.int32, data) 82 self._assert_dtype(np.int32, dtypes.int32, self._wrap_dict(data)) 83 84 def test_input_int64(self): 85 data = np.matrix([[1, 2], [3, 4]], dtype=np.int64) 86 self._assert_dtype(np.int64, dtypes.int64, data) 87 self._assert_dtype(np.int64, dtypes.int64, self._wrap_dict(data)) 88 89 def test_input_uint32(self): 90 data = np.matrix([[1, 2], [3, 4]], dtype=np.uint32) 91 self._assert_dtype(np.uint32, dtypes.uint32, data) 92 self._assert_dtype(np.uint32, dtypes.uint32, self._wrap_dict(data)) 93 94 def test_input_uint64(self): 95 data = np.matrix([[1, 2], [3, 4]], dtype=np.uint64) 96 self._assert_dtype(np.uint64, dtypes.uint64, data) 97 self._assert_dtype(np.uint64, dtypes.uint64, self._wrap_dict(data)) 98 99 def test_input_uint8(self): 100 data = np.matrix([[1, 2], [3, 4]], dtype=np.uint8) 101 self._assert_dtype(np.uint8, dtypes.uint8, data) 102 self._assert_dtype(np.uint8, dtypes.uint8, self._wrap_dict(data)) 103 104 def test_input_uint16(self): 105 data = np.matrix([[1, 2], [3, 4]], dtype=np.uint16) 106 self._assert_dtype(np.uint16, dtypes.uint16, data) 107 self._assert_dtype(np.uint16, dtypes.uint16, self._wrap_dict(data)) 108 109 def test_input_float16(self): 110 data = np.matrix([[1, 2], [3, 4]], dtype=np.float16) 111 self._assert_dtype(np.float16, dtypes.float16, data) 112 self._assert_dtype(np.float16, dtypes.float16, self._wrap_dict(data)) 113 114 def test_input_float32(self): 115 data = np.matrix([[1, 2], [3, 4]], dtype=np.float32) 116 self._assert_dtype(np.float32, dtypes.float32, data) 117 self._assert_dtype(np.float32, dtypes.float32, self._wrap_dict(data)) 118 119 def test_input_float64(self): 120 data = np.matrix([[1, 2], [3, 4]], dtype=np.float64) 121 self._assert_dtype(np.float64, dtypes.float64, data) 122 self._assert_dtype(np.float64, dtypes.float64, self._wrap_dict(data)) 123 124 def test_input_bool(self): 125 data = np.array([[False for _ in xrange(2)] for _ in xrange(2)]) 126 self._assert_dtype(np.bool, dtypes.bool, data) 127 self._assert_dtype(np.bool, dtypes.bool, self._wrap_dict(data)) 128 129 def test_input_string(self): 130 input_data = np.array([['str%d' % i for i in xrange(2)] for _ in xrange(2)]) 131 self._assert_dtype(input_data.dtype, dtypes.string, input_data) 132 self._assert_dtype(input_data.dtype, dtypes.string, 133 self._wrap_dict(input_data)) 134 135 def _assertAllClose(self, src, dest, src_key_of=None, src_prop=None): 136 137 def func(x): 138 val = getattr(x, src_prop) if src_prop else x 139 return val if src_key_of is None else src_key_of[val] 140 141 if isinstance(src, dict): 142 for k in list(src.keys()): 143 self.assertAllClose(func(src[k]), dest) 144 else: 145 self.assertAllClose(func(src), dest) 146 147 def test_unsupervised(self): 148 149 def func(feeder): 150 with self.cached_session(): 151 inp, _ = feeder.input_builder() 152 feed_dict_fn = feeder.get_feed_dict_fn() 153 feed_dict = feed_dict_fn() 154 self._assertAllClose(inp, [[1, 2]], feed_dict, 'name') 155 156 data = np.matrix([[1, 2], [2, 3], [3, 4]]) 157 func(data_feeder.DataFeeder(data, None, n_classes=0, batch_size=1)) 158 func( 159 data_feeder.DataFeeder( 160 self._wrap_dict(data), None, n_classes=0, batch_size=1)) 161 162 def test_data_feeder_regression(self): 163 164 def func(df): 165 inp, out = df.input_builder() 166 feed_dict_fn = df.get_feed_dict_fn() 167 feed_dict = feed_dict_fn() 168 self._assertAllClose(inp, [[3, 4], [1, 2]], feed_dict, 'name') 169 self._assertAllClose(out, [2, 1], feed_dict, 'name') 170 171 x = np.matrix([[1, 2], [3, 4]]) 172 y = np.array([1, 2]) 173 func(data_feeder.DataFeeder(x, y, n_classes=0, batch_size=3)) 174 func( 175 data_feeder.DataFeeder( 176 self._wrap_dict(x, 'in'), 177 self._wrap_dict(y, 'out'), 178 n_classes=self._wrap_dict(0, 'out'), 179 batch_size=3)) 180 181 def test_epoch(self): 182 183 def func(feeder): 184 with self.cached_session(): 185 feeder.input_builder() 186 epoch = feeder.make_epoch_variable() 187 feed_dict_fn = feeder.get_feed_dict_fn() 188 # First input 189 feed_dict = feed_dict_fn() 190 self.assertAllClose(feed_dict[epoch.name], [0]) 191 # Second input 192 feed_dict = feed_dict_fn() 193 self.assertAllClose(feed_dict[epoch.name], [0]) 194 # Third input 195 feed_dict = feed_dict_fn() 196 self.assertAllClose(feed_dict[epoch.name], [0]) 197 # Back to the first input again, so new epoch. 198 feed_dict = feed_dict_fn() 199 self.assertAllClose(feed_dict[epoch.name], [1]) 200 201 data = np.matrix([[1, 2], [2, 3], [3, 4]]) 202 labels = np.array([0, 0, 1]) 203 func(data_feeder.DataFeeder(data, labels, n_classes=0, batch_size=1)) 204 func( 205 data_feeder.DataFeeder( 206 self._wrap_dict(data, 'in'), 207 self._wrap_dict(labels, 'out'), 208 n_classes=self._wrap_dict(0, 'out'), 209 batch_size=1)) 210 211 def test_data_feeder_multioutput_regression(self): 212 213 def func(df): 214 inp, out = df.input_builder() 215 feed_dict_fn = df.get_feed_dict_fn() 216 feed_dict = feed_dict_fn() 217 self._assertAllClose(inp, [[3, 4], [1, 2]], feed_dict, 'name') 218 self._assertAllClose(out, [[3, 4], [1, 2]], feed_dict, 'name') 219 220 x = np.matrix([[1, 2], [3, 4]]) 221 y = np.array([[1, 2], [3, 4]]) 222 func(data_feeder.DataFeeder(x, y, n_classes=0, batch_size=2)) 223 func( 224 data_feeder.DataFeeder( 225 self._wrap_dict(x, 'in'), 226 self._wrap_dict(y, 'out'), 227 n_classes=self._wrap_dict(0, 'out'), 228 batch_size=2)) 229 230 def test_data_feeder_multioutput_classification(self): 231 232 def func(df): 233 inp, out = df.input_builder() 234 feed_dict_fn = df.get_feed_dict_fn() 235 feed_dict = feed_dict_fn() 236 self._assertAllClose(inp, [[3, 4], [1, 2]], feed_dict, 'name') 237 self._assertAllClose( 238 out, [[[0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]], 239 [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]], feed_dict, 240 'name') 241 242 x = np.matrix([[1, 2], [3, 4]]) 243 y = np.array([[0, 1, 2], [2, 3, 4]]) 244 func(data_feeder.DataFeeder(x, y, n_classes=5, batch_size=2)) 245 func( 246 data_feeder.DataFeeder( 247 self._wrap_dict(x, 'in'), 248 self._wrap_dict(y, 'out'), 249 n_classes=self._wrap_dict(5, 'out'), 250 batch_size=2)) 251 252 def test_streaming_data_feeder(self): 253 254 def func(df): 255 inp, out = df.input_builder() 256 feed_dict_fn = df.get_feed_dict_fn() 257 feed_dict = feed_dict_fn() 258 self._assertAllClose(inp, [[[1, 2]], [[3, 4]]], feed_dict, 'name') 259 self._assertAllClose(out, [[[1], [2]], [[2], [2]]], feed_dict, 'name') 260 261 def x_iter(wrap_dict=False): 262 yield np.array([[1, 2]]) if not wrap_dict else self._wrap_dict( 263 np.array([[1, 2]]), 'in') 264 yield np.array([[3, 4]]) if not wrap_dict else self._wrap_dict( 265 np.array([[3, 4]]), 'in') 266 267 def y_iter(wrap_dict=False): 268 yield np.array([[1], [2]]) if not wrap_dict else self._wrap_dict( 269 np.array([[1], [2]]), 'out') 270 yield np.array([[2], [2]]) if not wrap_dict else self._wrap_dict( 271 np.array([[2], [2]]), 'out') 272 273 func( 274 data_feeder.StreamingDataFeeder( 275 x_iter(), y_iter(), n_classes=0, batch_size=2)) 276 func( 277 data_feeder.StreamingDataFeeder( 278 x_iter(True), 279 y_iter(True), 280 n_classes=self._wrap_dict(0, 'out'), 281 batch_size=2)) 282 # Test non-full batches. 283 func( 284 data_feeder.StreamingDataFeeder( 285 x_iter(), y_iter(), n_classes=0, batch_size=10)) 286 func( 287 data_feeder.StreamingDataFeeder( 288 x_iter(True), 289 y_iter(True), 290 n_classes=self._wrap_dict(0, 'out'), 291 batch_size=10)) 292 293 def test_dask_data_feeder(self): 294 if HAS_PANDAS and HAS_DASK: 295 x = pd.DataFrame( 296 dict( 297 a=np.array([.1, .3, .4, .6, .2, .1, .6]), 298 b=np.array([.7, .8, .1, .2, .5, .3, .9]))) 299 x = dd.from_pandas(x, npartitions=2) 300 y = pd.DataFrame(dict(labels=np.array([1, 0, 2, 1, 0, 1, 2]))) 301 y = dd.from_pandas(y, npartitions=2) 302 # TODO(ipolosukhin): Remove or restore this. 303 # x = extract_dask_data(x) 304 # y = extract_dask_labels(y) 305 df = data_feeder.DaskDataFeeder(x, y, n_classes=2, batch_size=2) 306 inp, out = df.input_builder() 307 feed_dict_fn = df.get_feed_dict_fn() 308 feed_dict = feed_dict_fn() 309 self.assertAllClose(feed_dict[inp.name], [[0.40000001, 0.1], 310 [0.60000002, 0.2]]) 311 self.assertAllClose(feed_dict[out.name], [[0., 0., 1.], [0., 1., 0.]]) 312 313 # TODO(rohanj): Fix this test by fixing data_feeder. Currently, h5py doesn't 314 # support permutation based indexing lookups (More documentation at 315 # http://docs.h5py.org/en/latest/high/dataset.html#fancy-indexing) 316 def DISABLED_test_hdf5_data_feeder(self): 317 318 def func(df): 319 inp, out = df.input_builder() 320 feed_dict_fn = df.get_feed_dict_fn() 321 feed_dict = feed_dict_fn() 322 self._assertAllClose(inp, [[3, 4], [1, 2]], feed_dict, 'name') 323 self.assertAllClose(out, [2, 1], feed_dict, 'name') 324 325 try: 326 import h5py # pylint: disable=g-import-not-at-top 327 x = np.matrix([[1, 2], [3, 4]]) 328 y = np.array([1, 2]) 329 file_path = os.path.join(self._base_dir, 'test_hdf5.h5') 330 h5f = h5py.File(file_path, 'w') 331 h5f.create_dataset('x', data=x) 332 h5f.create_dataset('y', data=y) 333 h5f.close() 334 h5f = h5py.File(file_path, 'r') 335 x = h5f['x'] 336 y = h5f['y'] 337 func(data_feeder.DataFeeder(x, y, n_classes=0, batch_size=3)) 338 func( 339 data_feeder.DataFeeder( 340 self._wrap_dict(x, 'in'), 341 self._wrap_dict(y, 'out'), 342 n_classes=self._wrap_dict(0, 'out'), 343 batch_size=3)) 344 except ImportError: 345 print("Skipped test for hdf5 since it's not installed.") 346 347 348class SetupPredictDataFeederTest(DataFeederTest): 349 """Tests for `DataFeeder.setup_predict_data_feeder`.""" 350 351 def test_iterable_data(self): 352 # pylint: disable=undefined-variable 353 354 def func(df): 355 self._assertAllClose(six.next(df), [[1, 2], [3, 4]]) 356 self._assertAllClose(six.next(df), [[5, 6]]) 357 358 data = [[1, 2], [3, 4], [5, 6]] 359 x = iter(data) 360 x_dict = iter([self._wrap_dict(v) for v in iter(data)]) 361 func(data_feeder.setup_predict_data_feeder(x, batch_size=2)) 362 func(data_feeder.setup_predict_data_feeder(x_dict, batch_size=2)) 363 364 365if __name__ == '__main__': 366 test.main() 367