# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for data_utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from itertools import cycle
22import os
23import tarfile
24import zipfile
25
26import numpy as np
27from six.moves.urllib.parse import urljoin
28from six.moves.urllib.request import pathname2url
29
30from tensorflow.python import keras
31from tensorflow.python.keras.utils import data_utils
32from tensorflow.python.platform import test
33

class TestGetFileAndValidateIt(test.TestCase):
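  """Checks `get_file` download, archive extraction, and hash validation."""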

  def test_get_file_and_validate_it(self):
    """Tests get_file from a url, plus extraction and validation."""
    dest_dir = self.get_temp_dir()
    orig_dir = self.get_temp_dir()

    text_file_path = os.path.join(orig_dir, 'test.txt')
    zip_file_path = os.path.join(orig_dir, 'test.zip')
    tar_file_path = os.path.join(orig_dir, 'test.tar.gz')

    with open(text_file_path, 'w') as text_file:
      text_file.write('Float like a butterfly, sting like a bee.')

    with tarfile.open(tar_file_path, 'w:gz') as tar_file:
      tar_file.add(text_file_path)

    with zipfile.ZipFile(zip_file_path, 'w') as zip_file:
      zip_file.write(text_file_path)

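    # Build a file:// URL so get_file can fetch from the local filesystem
    # rather than a remote server.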
    origin = urljoin('file://', pathname2url(os.path.abspath(tar_file_path)))

    path = keras.utils.data_utils.get_file('test.txt', origin,
                                           untar=True, cache_subdir=dest_dir)
    filepath = path + '.tar.gz'
    hashval_sha256 = keras.utils.data_utils._hash_file(filepath)
    hashval_md5 = keras.utils.data_utils._hash_file(filepath, algorithm='md5')
    path = keras.utils.data_utils.get_file(
        'test.txt', origin, md5_hash=hashval_md5,
        untar=True, cache_subdir=dest_dir)
    path = keras.utils.data_utils.get_file(
        filepath, origin, file_hash=hashval_sha256,
        extract=True, cache_subdir=dest_dir)
    self.assertTrue(os.path.exists(filepath))
    self.assertTrue(keras.utils.data_utils.validate_file(filepath,
                                                         hashval_sha256))
    self.assertTrue(keras.utils.data_utils.validate_file(filepath, hashval_md5))
    os.remove(filepath)

    origin = urljoin('file://', pathname2url(os.path.abspath(zip_file_path)))

    hashval_sha256 = keras.utils.data_utils._hash_file(zip_file_path)
    hashval_md5 = keras.utils.data_utils._hash_file(zip_file_path,
                                                    algorithm='md5')
    path = keras.utils.data_utils.get_file(
        'test', origin, md5_hash=hashval_md5,
        extract=True, cache_subdir=dest_dir)
    path = keras.utils.data_utils.get_file(
        'test', origin, file_hash=hashval_sha256,
        extract=True, cache_subdir=dest_dir)
    self.assertTrue(os.path.exists(path))
    self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_sha256))
    self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_md5))


class TestSequence(keras.utils.data_utils.Sequence):
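  """Dummy `Sequence` of constant-valued arrays.

  Item `i` is an array of shape `shape` filled with `i * value`;
  `on_epoch_end` multiplies subsequent values by 5 so epoch transitions
  are observable.
  """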

  def __init__(self, shape, value=1.):
    self.shape = shape
    self.inner = value

  def __getitem__(self, item):
    return np.ones(self.shape, dtype=np.uint32) * item * self.inner

  def __len__(self):
    return 100

  def on_epoch_end(self):
    self.inner *= 5.0


class FaultSequence(keras.utils.data_utils.Sequence):
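  """Sequence whose `__getitem__` always raises, to test error propagation."""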

  def __getitem__(self, item):
    raise IndexError(item, 'item is not present')

  def __len__(self):
    return 100


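# `threadsafe_generator` serializes access to the wrapped generator so that
# several enqueuer threads can pull items from it safely.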
@data_utils.threadsafe_generator
def create_generator_from_sequence_threads(ds):
  for i in cycle(range(len(ds))):
    yield ds[i]


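# Non-decorated counterpart, presumably intended for process-based enqueuers
# ('pcs' = processes), where each worker operates on its own generator copy.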
def create_generator_from_sequence_pcs(ds):
  for i in cycle(range(len(ds))):
    yield ds[i]


class TestEnqueuers(test.TestCase):
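  """Tests for the GeneratorEnqueuer and OrderedEnqueuer classes."""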

  def test_generator_enqueuer_threads(self):
    enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
        create_generator_from_sequence_threads(TestSequence([3, 200, 200, 3])),
        use_multiprocessing=False)
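    # start(workers, max_queue_size): three worker threads, queue capacity 10.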
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for _ in range(100):
      acc.append(int(next(gen_output)[0, 0, 0, 0]))

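    # Worker threads may interleave, so check membership rather than order.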
    self.assertEqual(len(set(acc) - set(range(100))), 0)
    enqueuer.stop()

  @data_utils.dont_use_multiprocessing_pool
  def test_generator_enqueuer_processes(self):
    enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
        create_generator_from_sequence_threads(TestSequence([3, 200, 200, 3])),
        use_multiprocessing=True)
    enqueuer.start(4, 10)
    gen_output = enqueuer.get()
    acc = []
    for _ in range(300):
      acc.append(int(next(gen_output)[0, 0, 0, 0]))
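    # Each worker process advances its own copy of the generator, so the
    # output need not match a single ordered pass over the Sequence.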
    self.assertNotEqual(acc, list(range(100)))
    enqueuer.stop()

  def test_generator_enqueuer_fail_threads(self):
    enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
        create_generator_from_sequence_threads(FaultSequence()),
        use_multiprocessing=False)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    with self.assertRaises(IndexError):
      next(gen_output)

  @data_utils.dont_use_multiprocessing_pool
  def test_generator_enqueuer_fail_processes(self):
    enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
        create_generator_from_sequence_threads(FaultSequence()),
        use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    with self.assertRaises(IndexError):
      next(gen_output)

  def test_ordered_enqueuer_threads(self):
    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
        TestSequence([3, 200, 200, 3]), use_multiprocessing=False)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for _ in range(100):
      acc.append(next(gen_output)[0, 0, 0, 0])
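    # Unlike GeneratorEnqueuer, OrderedEnqueuer must preserve index order.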
    self.assertEqual(acc, list(range(100)))
    enqueuer.stop()

  @data_utils.dont_use_multiprocessing_pool
  def test_ordered_enqueuer_processes(self):
    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
        TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for _ in range(100):
      acc.append(next(gen_output)[0, 0, 0, 0])
    self.assertEqual(acc, list(range(100)))
    enqueuer.stop()

  def test_ordered_enqueuer_fail_threads(self):
    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
        FaultSequence(), use_multiprocessing=False)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    with self.assertRaises(IndexError):
      next(gen_output)

  @data_utils.dont_use_multiprocessing_pool
  def test_ordered_enqueuer_fail_processes(self):
    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
        FaultSequence(), use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    with self.assertRaises(IndexError):
      next(gen_output)

  @data_utils.dont_use_multiprocessing_pool
  def test_on_epoch_end_processes(self):
    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
        TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for _ in range(200):
      acc.append(next(gen_output)[0, 0, 0, 0])
    # Check that order was kept by the OrderedEnqueuer with processes and that
    # on_epoch_end ran: the second epoch's values are scaled by 5.
    self.assertEqual(acc[100:], [k * 5 for k in range(100)])
    enqueuer.stop()

  @data_utils.dont_use_multiprocessing_pool
  def test_context_switch(self):
    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
        TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
    enqueuer2 = keras.utils.data_utils.OrderedEnqueuer(
        TestSequence([3, 200, 200, 3], value=15), use_multiprocessing=True)
    enqueuer.start(3, 10)
    enqueuer2.start(3, 10)
    gen_output = enqueuer.get()
    gen_output2 = enqueuer2.get()
    acc = []
    for _ in range(100):
      acc.append(next(gen_output)[0, 0, 0, 0])
    self.assertEqual(acc[-1], 99)
    # One epoch is complete, so we can switch over to the second enqueuer's
    # Sequence.

    acc = []
    self.skipTest('b/145555807 flakily timing out.')
    for _ in range(100):
      acc.append(next(gen_output2)[0, 0, 0, 0])
    self.assertEqual(acc[-1], 99 * 15)
    # One epoch has completed on enqueuer2 as well.

    # Be sure that both Sequences were updated by on_epoch_end.
    self.assertEqual(next(gen_output)[0, 0, 0, 0], 0)
    self.assertEqual(next(gen_output)[0, 0, 0, 0], 5)
    self.assertEqual(next(gen_output2)[0, 0, 0, 0], 0)
    self.assertEqual(next(gen_output2)[0, 0, 0, 0], 15 * 5)

    # Tear down everything
    enqueuer.stop()
    enqueuer2.stop()

  def test_on_epoch_end_threads(self):
    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
        TestSequence([3, 200, 200, 3]), use_multiprocessing=False)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for _ in range(100):
      acc.append(next(gen_output)[0, 0, 0, 0])
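    # First epoch consumed; on_epoch_end should have scaled values by 5.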
    acc = []
    for _ in range(100):
      acc.append(next(gen_output)[0, 0, 0, 0])
    # Check that order was kept by the OrderedEnqueuer with threads.
    self.assertEqual(acc, [k * 5 for k in range(100)])
    enqueuer.stop()


if __name__ == '__main__':
  # Bazel sets these environment variables to very long paths.
  # Tempfile uses them to create long paths, and in turn the multiprocessing
  # library tries to create sockets named after those paths. Delete whatever
  # Bazel writes to these to avoid tests failing due to socket addresses being
  # too long.
  for var in ('TMPDIR', 'TMP', 'TEMP'):
    if var in os.environ:
      del os.environ[var]

  test.main()