1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for data_utils.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from itertools import cycle 22import os 23import tarfile 24import zipfile 25 26import numpy as np 27from six.moves.urllib.parse import urljoin 28from six.moves.urllib.request import pathname2url 29 30from tensorflow.python import keras 31from tensorflow.python.keras.utils import data_utils 32from tensorflow.python.platform import test 33 34 35class TestGetFileAndValidateIt(test.TestCase): 36 37 def test_get_file_and_validate_it(self): 38 """Tests get_file from a url, plus extraction and validation. 39 """ 40 dest_dir = self.get_temp_dir() 41 orig_dir = self.get_temp_dir() 42 43 text_file_path = os.path.join(orig_dir, 'test.txt') 44 zip_file_path = os.path.join(orig_dir, 'test.zip') 45 tar_file_path = os.path.join(orig_dir, 'test.tar.gz') 46 47 with open(text_file_path, 'w') as text_file: 48 text_file.write('Float like a butterfly, sting like a bee.') 49 50 with tarfile.open(tar_file_path, 'w:gz') as tar_file: 51 tar_file.add(text_file_path) 52 53 with zipfile.ZipFile(zip_file_path, 'w') as zip_file: 54 zip_file.write(text_file_path) 55 56 origin = urljoin('file://', pathname2url(os.path.abspath(tar_file_path))) 57 58 path = keras.utils.data_utils.get_file('test.txt', origin, 59 untar=True, cache_subdir=dest_dir) 60 filepath = path + '.tar.gz' 61 hashval_sha256 = keras.utils.data_utils._hash_file(filepath) 62 hashval_md5 = keras.utils.data_utils._hash_file(filepath, algorithm='md5') 63 path = keras.utils.data_utils.get_file( 64 'test.txt', origin, md5_hash=hashval_md5, 65 untar=True, cache_subdir=dest_dir) 66 path = keras.utils.data_utils.get_file( 67 filepath, origin, file_hash=hashval_sha256, 68 extract=True, cache_subdir=dest_dir) 69 self.assertTrue(os.path.exists(filepath)) 70 self.assertTrue(keras.utils.data_utils.validate_file(filepath, 71 hashval_sha256)) 72 self.assertTrue(keras.utils.data_utils.validate_file(filepath, hashval_md5)) 73 os.remove(filepath) 74 75 origin = urljoin('file://', pathname2url(os.path.abspath(zip_file_path))) 76 77 hashval_sha256 = keras.utils.data_utils._hash_file(zip_file_path) 78 hashval_md5 = keras.utils.data_utils._hash_file(zip_file_path, 79 algorithm='md5') 80 path = keras.utils.data_utils.get_file( 81 'test', origin, md5_hash=hashval_md5, 82 extract=True, cache_subdir=dest_dir) 83 path = keras.utils.data_utils.get_file( 84 'test', origin, file_hash=hashval_sha256, 85 extract=True, cache_subdir=dest_dir) 86 self.assertTrue(os.path.exists(path)) 87 self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_sha256)) 88 self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_md5)) 89 90 91class TestSequence(keras.utils.data_utils.Sequence): 92 93 def __init__(self, shape, value=1.): 94 self.shape = shape 95 self.inner = value 96 97 def __getitem__(self, item): 98 return np.ones(self.shape, dtype=np.uint32) * item * self.inner 99 100 def __len__(self): 101 return 100 102 103 def on_epoch_end(self): 104 self.inner *= 5.0 105 106 107class FaultSequence(keras.utils.data_utils.Sequence): 108 109 def __getitem__(self, item): 110 raise IndexError(item, 'item is not present') 111 112 def __len__(self): 113 return 100 114 115 116@data_utils.threadsafe_generator 117def create_generator_from_sequence_threads(ds): 118 for i in cycle(range(len(ds))): 119 yield ds[i] 120 121 122def create_generator_from_sequence_pcs(ds): 123 for i in cycle(range(len(ds))): 124 yield ds[i] 125 126 127class TestEnqueuers(test.TestCase): 128 129 def test_generator_enqueuer_threads(self): 130 enqueuer = keras.utils.data_utils.GeneratorEnqueuer( 131 create_generator_from_sequence_threads(TestSequence([3, 200, 200, 3])), 132 use_multiprocessing=False) 133 enqueuer.start(3, 10) 134 gen_output = enqueuer.get() 135 acc = [] 136 for _ in range(100): 137 acc.append(int(next(gen_output)[0, 0, 0, 0])) 138 139 self.assertEqual(len(set(acc) - set(range(100))), 0) 140 enqueuer.stop() 141 142 @data_utils.dont_use_multiprocessing_pool 143 def test_generator_enqueuer_processes(self): 144 enqueuer = keras.utils.data_utils.GeneratorEnqueuer( 145 create_generator_from_sequence_threads(TestSequence([3, 200, 200, 3])), 146 use_multiprocessing=True) 147 enqueuer.start(4, 10) 148 gen_output = enqueuer.get() 149 acc = [] 150 for _ in range(300): 151 acc.append(int(next(gen_output)[0, 0, 0, 0])) 152 self.assertNotEqual(acc, list(range(100))) 153 enqueuer.stop() 154 155 def test_generator_enqueuer_fail_threads(self): 156 enqueuer = keras.utils.data_utils.GeneratorEnqueuer( 157 create_generator_from_sequence_threads(FaultSequence()), 158 use_multiprocessing=False) 159 enqueuer.start(3, 10) 160 gen_output = enqueuer.get() 161 with self.assertRaises(IndexError): 162 next(gen_output) 163 164 @data_utils.dont_use_multiprocessing_pool 165 def test_generator_enqueuer_fail_processes(self): 166 enqueuer = keras.utils.data_utils.GeneratorEnqueuer( 167 create_generator_from_sequence_threads(FaultSequence()), 168 use_multiprocessing=True) 169 enqueuer.start(3, 10) 170 gen_output = enqueuer.get() 171 with self.assertRaises(IndexError): 172 next(gen_output) 173 174 def test_ordered_enqueuer_threads(self): 175 enqueuer = keras.utils.data_utils.OrderedEnqueuer( 176 TestSequence([3, 200, 200, 3]), use_multiprocessing=False) 177 enqueuer.start(3, 10) 178 gen_output = enqueuer.get() 179 acc = [] 180 for _ in range(100): 181 acc.append(next(gen_output)[0, 0, 0, 0]) 182 self.assertEqual(acc, list(range(100))) 183 enqueuer.stop() 184 185 @data_utils.dont_use_multiprocessing_pool 186 def test_ordered_enqueuer_processes(self): 187 enqueuer = keras.utils.data_utils.OrderedEnqueuer( 188 TestSequence([3, 200, 200, 3]), use_multiprocessing=True) 189 enqueuer.start(3, 10) 190 gen_output = enqueuer.get() 191 acc = [] 192 for _ in range(100): 193 acc.append(next(gen_output)[0, 0, 0, 0]) 194 self.assertEqual(acc, list(range(100))) 195 enqueuer.stop() 196 197 def test_ordered_enqueuer_fail_threads(self): 198 enqueuer = keras.utils.data_utils.OrderedEnqueuer( 199 FaultSequence(), use_multiprocessing=False) 200 enqueuer.start(3, 10) 201 gen_output = enqueuer.get() 202 with self.assertRaises(IndexError): 203 next(gen_output) 204 205 @data_utils.dont_use_multiprocessing_pool 206 def test_ordered_enqueuer_fail_processes(self): 207 enqueuer = keras.utils.data_utils.OrderedEnqueuer( 208 FaultSequence(), use_multiprocessing=True) 209 enqueuer.start(3, 10) 210 gen_output = enqueuer.get() 211 with self.assertRaises(IndexError): 212 next(gen_output) 213 214 @data_utils.dont_use_multiprocessing_pool 215 def test_on_epoch_end_processes(self): 216 enqueuer = keras.utils.data_utils.OrderedEnqueuer( 217 TestSequence([3, 200, 200, 3]), use_multiprocessing=True) 218 enqueuer.start(3, 10) 219 gen_output = enqueuer.get() 220 acc = [] 221 for _ in range(200): 222 acc.append(next(gen_output)[0, 0, 0, 0]) 223 # Check that order was keep in GeneratorEnqueuer with processes 224 self.assertEqual(acc[100:], list([k * 5 for k in range(100)])) 225 enqueuer.stop() 226 227 @data_utils.dont_use_multiprocessing_pool 228 def test_context_switch(self): 229 enqueuer = keras.utils.data_utils.OrderedEnqueuer( 230 TestSequence([3, 200, 200, 3]), use_multiprocessing=True) 231 enqueuer2 = keras.utils.data_utils.OrderedEnqueuer( 232 TestSequence([3, 200, 200, 3], value=15), use_multiprocessing=True) 233 enqueuer.start(3, 10) 234 enqueuer2.start(3, 10) 235 gen_output = enqueuer.get() 236 gen_output2 = enqueuer2.get() 237 acc = [] 238 for _ in range(100): 239 acc.append(next(gen_output)[0, 0, 0, 0]) 240 self.assertEqual(acc[-1], 99) 241 # One epoch is completed so enqueuer will switch the Sequence 242 243 acc = [] 244 self.skipTest('b/145555807 flakily timing out.') 245 for _ in range(100): 246 acc.append(next(gen_output2)[0, 0, 0, 0]) 247 self.assertEqual(acc[-1], 99 * 15) 248 # One epoch has been completed so enqueuer2 will switch 249 250 # Be sure that both Sequence were updated 251 self.assertEqual(next(gen_output)[0, 0, 0, 0], 0) 252 self.assertEqual(next(gen_output)[0, 0, 0, 0], 5) 253 self.assertEqual(next(gen_output2)[0, 0, 0, 0], 0) 254 self.assertEqual(next(gen_output2)[0, 0, 0, 0], 15 * 5) 255 256 # Tear down everything 257 enqueuer.stop() 258 enqueuer2.stop() 259 260 def test_on_epoch_end_threads(self): 261 enqueuer = keras.utils.data_utils.OrderedEnqueuer( 262 TestSequence([3, 200, 200, 3]), use_multiprocessing=False) 263 enqueuer.start(3, 10) 264 gen_output = enqueuer.get() 265 acc = [] 266 for _ in range(100): 267 acc.append(next(gen_output)[0, 0, 0, 0]) 268 acc = [] 269 for _ in range(100): 270 acc.append(next(gen_output)[0, 0, 0, 0]) 271 # Check that order was keep in GeneratorEnqueuer with processes 272 self.assertEqual(acc, list([k * 5 for k in range(100)])) 273 enqueuer.stop() 274 275 276if __name__ == '__main__': 277 # Bazel sets these environment variables to very long paths. 278 # Tempfile uses them to create long paths, and in turn multiprocessing 279 # library tries to create sockets named after paths. Delete whatever bazel 280 # writes to these to avoid tests failing due to socket addresses being too 281 # long. 282 for var in ('TMPDIR', 'TMP', 'TEMP'): 283 if var in os.environ: 284 del os.environ[var] 285 286 test.main() 287