# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
"""Utilities for file download and caching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from abc import abstractmethod
from contextlib import closing
import hashlib
import multiprocessing
from multiprocessing.pool import ThreadPool
import os
import random
import shutil
import sys
import tarfile
import threading
import time
import zipfile

import numpy as np
import six
from six.moves.urllib.error import HTTPError
from six.moves.urllib.error import URLError
from six.moves.urllib.request import urlopen

from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export


try:
  import queue
except ImportError:
  import Queue as queue


if sys.version_info[0] == 2:

  def urlretrieve(url, filename, reporthook=None, data=None):
    """Replacement for `urlretrieve` for Python 2.

    Under Python 2, `urlretrieve` relies on `FancyURLopener` from the legacy
    `urllib` module, known to have issues with proxy management.

    Arguments:
        url: url to retrieve.
        filename: where to store the retrieved data locally.
        reporthook: a hook function that will be called once
            on establishment of the network connection and once
            after each block read thereafter.
            The hook will be passed three arguments;
            a count of blocks transferred so far,
            a block size in bytes, and the total size of the file.
        data: `data` argument passed to `urlopen`.
    """

    def chunk_read(response, chunk_size=8192, reporthook=None):
      content_type = response.info().get('Content-Length')
      total_size = -1
      if content_type is not None:
        total_size = int(content_type.strip())
      count = 0
      while True:
        chunk = response.read(chunk_size)
        count += 1
        if reporthook is not None:
          reporthook(count, chunk_size, total_size)
        if chunk:
          yield chunk
        else:
          break

    response = urlopen(url, data)
    with open(filename, 'wb') as fd:
      for chunk in chunk_read(response, reporthook=reporthook):
        fd.write(chunk)
else:
  from six.moves.urllib.request import urlretrieve


def is_generator_or_sequence(x):
  """Check if `x` is a Keras generator type."""
  return tf_inspect.isgenerator(x) or isinstance(x, Sequence)
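# Illustrative sketch (not part of the original module): a minimal example of
# what `is_generator_or_sequence` accepts, namely a plain Python generator or
# a `Sequence` subclass. `_DemoSequence` is a hypothetical class used only for
# this illustration; `Sequence` itself is defined later in this module.
def _demo_is_generator_or_sequence():
  """Shows `is_generator_or_sequence` on a generator, a Sequence and a list."""

  def gen():
    yield 0

  class _DemoSequence(Sequence):

    def __len__(self):
      return 1

    def __getitem__(self, index):
      return index

  assert is_generator_or_sequence(gen())  # A generator object matches.
  assert is_generator_or_sequence(_DemoSequence())  # A Sequence matches.
  assert not is_generator_or_sequence([0, 1, 2])  # A plain list does not.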
def _extract_archive(file_path, path='.', archive_format='auto'):
  """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.

  Arguments:
      file_path: path to the archive file
      path: path to extract the archive file
      archive_format: Archive format to try for extracting the file.
          Options are 'auto', 'tar', 'zip', and None.
          'tar' includes tar, tar.gz, and tar.bz files.
          The default 'auto' is ['tar', 'zip'].
          None or an empty list will return no matches found.

  Returns:
      True if a match was found and an archive extraction was completed,
      False otherwise.
  """
  if archive_format is None:
    return False
  if archive_format == 'auto':
    archive_format = ['tar', 'zip']
  if isinstance(archive_format, six.string_types):
    archive_format = [archive_format]

  for archive_type in archive_format:
    if archive_type == 'tar':
      open_fn = tarfile.open
      is_match_fn = tarfile.is_tarfile
    if archive_type == 'zip':
      open_fn = zipfile.ZipFile
      is_match_fn = zipfile.is_zipfile

    if is_match_fn(file_path):
      with open_fn(file_path) as archive:
        try:
          archive.extractall(path)
        except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
          if os.path.exists(path):
            if os.path.isfile(path):
              os.remove(path)
            else:
              shutil.rmtree(path)
          raise
      return True
  return False
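# Illustrative sketch (not part of the original module): how `_extract_archive`
# might be called directly. The archive and target paths below are
# hypothetical placeholders.
def _demo_extract_archive(archive_path='/tmp/example.tar.gz',
                          target_dir='/tmp/example_extracted'):
  """Extracts `archive_path` into `target_dir` if its format is recognized."""
  if _extract_archive(archive_path, path=target_dir, archive_format='auto'):
    print('Extracted archive to', target_dir)
  else:
    print('No tar or zip archive recognized at', archive_path)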
@keras_export('keras.utils.get_file')
def get_file(fname,
             origin,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
  """Downloads a file from a URL if it is not already in the cache.

  By default the file at the url `origin` is downloaded to the
  cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
  and given the filename `fname`. The final location of a file
  `example.txt` would therefore be `~/.keras/datasets/example.txt`.

  Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
  Passing a hash will verify the file after download. The command line
  programs `shasum` and `sha256sum` can compute the hash.

  Arguments:
      fname: Name of the file. If an absolute path `/path/to/file.txt` is
          specified the file will be saved at that location.
      origin: Original URL of the file.
      untar: Deprecated in favor of 'extract'.
          boolean, whether the file should be decompressed
      md5_hash: Deprecated in favor of 'file_hash'.
          md5 hash of the file for verification
      file_hash: The expected hash string of the file after download.
          The sha256 and md5 hash algorithms are both supported.
      cache_subdir: Subdirectory under the Keras cache dir where the file is
          saved. If an absolute path `/path/to/folder` is
          specified the file will be saved at that location.
      hash_algorithm: Select the hash algorithm to verify the file.
          options are 'md5', 'sha256', and 'auto'.
          The default 'auto' detects the hash algorithm in use.
      extract: True tries extracting the file as an Archive, like tar or zip.
      archive_format: Archive format to try for extracting the file.
          Options are 'auto', 'tar', 'zip', and None.
          'tar' includes tar, tar.gz, and tar.bz files.
          The default 'auto' is ['tar', 'zip'].
          None or an empty list will return no matches found.
      cache_dir: Location to store cached files, when None it
          defaults to the [Keras
          Directory](/faq/#where-is-the-keras-configuration-filed-stored).

  Returns:
      Path to the downloaded file
  """
  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'
  datadir_base = os.path.expanduser(cache_dir)
  if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join('/tmp', '.keras')
  datadir = os.path.join(datadir_base, cache_subdir)
  if not os.path.exists(datadir):
    os.makedirs(datadir)

  if untar:
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  download = False
  if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if file_hash is not None:
      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
        print('A local file was found, but it seems to be '
              'incomplete or outdated because the ' + hash_algorithm +
              ' file hash does not match the original value of ' + file_hash +
              ' so we will re-download the data.')
        download = True
  else:
    download = True

  if download:
    print('Downloading data from', origin)

    class ProgressTracker(object):
      # Maintain progbar for the lifetime of download.
      # This design was chosen for Python 2.7 compatibility.
      progbar = None

    def dl_progress(count, block_size, total_size):
      if ProgressTracker.progbar is None:
        if total_size == -1:
          total_size = None
        ProgressTracker.progbar = Progbar(total_size)
      else:
        ProgressTracker.progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        urlretrieve(origin, fpath, dl_progress)
      except HTTPError as e:
        raise Exception(error_msg.format(origin, e.code, e.msg))
      except URLError as e:
        raise Exception(error_msg.format(origin, e.errno, e.reason))
    except (Exception, KeyboardInterrupt) as e:
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    ProgressTracker.progbar = None

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  return fpath
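# Illustrative sketch (not part of the original module): a hedged example of
# calling `get_file` with hash verification and extraction. The URL and hash
# below are hypothetical placeholders; a 64-character hex digest is detected
# as sha256 by the default 'auto' hash algorithm.
def _demo_get_file():
  """Downloads and extracts a hypothetical dataset archive."""
  path = get_file(
      'example.tar.gz',
      origin='https://example.com/data/example.tar.gz',  # Placeholder URL.
      file_hash='0' * 64,  # Placeholder sha256 digest.
      extract=True,
      cache_subdir='datasets')
  # `path` points at the cached archive under `~/.keras/datasets/`.
  return path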
def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
  """Calculates a file sha256 or md5 hash.

  Example:

  ```python
      >>> from tensorflow.python.keras.utils.data_utils import _hash_file
      >>> _hash_file('/path/to/file.zip')
      'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
  ```

  Arguments:
      fpath: path to the file being validated
      algorithm: hash algorithm, one of 'auto', 'sha256', or 'md5'.
          The default 'auto' detects the hash algorithm in use.
      chunk_size: Bytes to read at a time, important for large files.

  Returns:
      The file hash
  """
  # With no reference hash to inspect here, 'auto' falls back to sha256.
  if algorithm in ('sha256', 'auto'):
    hasher = hashlib.sha256()
  else:
    hasher = hashlib.md5()

  with open(fpath, 'rb') as fpath_file:
    for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
      hasher.update(chunk)

  return hasher.hexdigest()


def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
  """Validates a file against a sha256 or md5 hash.

  Arguments:
      fpath: path to the file being validated
      file_hash: The expected hash string of the file.
          The sha256 and md5 hash algorithms are both supported.
      algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
          The default 'auto' detects the hash algorithm in use.
      chunk_size: Bytes to read at a time, important for large files.

  Returns:
      Whether the file is valid
  """
  if (algorithm == 'sha256') or (algorithm == 'auto' and len(file_hash) == 64):
    hasher = 'sha256'
  else:
    hasher = 'md5'

  return str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash)


@keras_export('keras.utils.Sequence')
class Sequence(object):
  """Base object for fitting to a sequence of data, such as a dataset.

  Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
  If you want to modify your dataset between epochs you may implement
  `on_epoch_end`.
  The method `__getitem__` should return a complete batch.

  Notes:

  `Sequence` is a safer way to do multiprocessing. This structure guarantees
  that the network will only train once
  on each sample per epoch, which is not the case with generators.

  Examples:

  ```python
      from skimage.io import imread
      from skimage.transform import resize
      import numpy as np
      import math

      # Here, `x_set` is a list of paths to the images
      # and `y_set` are the associated classes.

      class CIFAR10Sequence(Sequence):

          def __init__(self, x_set, y_set, batch_size):
              self.x, self.y = x_set, y_set
              self.batch_size = batch_size

          def __len__(self):
              return math.ceil(len(self.x) / self.batch_size)

          def __getitem__(self, idx):
              batch_x = self.x[idx * self.batch_size:(idx + 1) *
                  self.batch_size]
              batch_y = self.y[idx * self.batch_size:(idx + 1) *
                  self.batch_size]

              return np.array([
                  resize(imread(file_name), (200, 200))
                  for file_name in batch_x]), np.array(batch_y)
  ```
  """

  @abstractmethod
  def __getitem__(self, index):
    """Gets batch at position `index`.

    Arguments:
        index: position of the batch in the Sequence.

    Returns:
        A batch
    """
    raise NotImplementedError

  @abstractmethod
  def __len__(self):
    """Number of batches in the Sequence.

    Returns:
        The number of batches in the Sequence.
    """
    raise NotImplementedError

  def on_epoch_end(self):
    """Method called at the end of every epoch.
    """
    pass

  def __iter__(self):
    """Creates a generator that iterates over the Sequence."""
    for item in (self[i] for i in range(len(self))):
      yield item
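# Illustrative sketch (not part of the original module): the `Sequence`
# docstring above notes that `on_epoch_end` may be implemented to modify the
# dataset between epochs. `_ShuffleSequence` is a hypothetical subclass that
# reshuffles the sample order at the end of every epoch.
class _ShuffleSequence(Sequence):
  """Hypothetical example: yields `x`/`y` batches in a reshuffled order."""

  def __init__(self, x, y, batch_size):
    self.x, self.y = np.asarray(x), np.asarray(y)
    self.batch_size = batch_size
    self.indices = np.arange(len(self.x))

  def __len__(self):
    return int(np.ceil(len(self.x) / float(self.batch_size)))

  def __getitem__(self, idx):
    batch_idx = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
    return self.x[batch_idx], self.y[batch_idx]

  def on_epoch_end(self):
    # Reshuffle so that the next epoch sees the samples in a new order.
    np.random.shuffle(self.indices)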
def iter_sequence_infinite(seq):
  """Iterates indefinitely over a Sequence.

  Arguments:
      seq: Sequence instance.

  Yields:
      Batches of data from the Sequence.
  """
  while True:
    for item in seq:
      yield item


# Global variables to be shared across processes
_SHARED_SEQUENCES = {}
# We use a Value to provide a unique id to different processes.
_SEQUENCE_COUNTER = None


def init_pool(seqs):
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = seqs


def get_index(uid, i):
  """Gets the value from the Sequence `uid` at index `i`.

  To allow multiple Sequences to be used at the same time, we use `uid` to
  get a specific one. A single Sequence would cause the validation to
  overwrite the training Sequence.

  Arguments:
      uid: int, Sequence identifier
      i: index

  Returns:
      The value at index `i`.
  """
  return _SHARED_SEQUENCES[uid][i]
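# Illustrative sketch (not part of the original module): how `init_pool` and
# `get_index` fit together. Enqueuers register their `Sequence` under a unique
# `uid` in `_SHARED_SEQUENCES`, and workers look batches up via
# `get_index(uid, i)`. The single-process demo below uses a plain list as a
# stand-in for a `Sequence`, since `get_index` only indexes into it.
def _demo_shared_sequence_lookup():
  """Registers a fake sequence under uid 0 and fetches its items by index."""
  fake_sequence = ['batch_0', 'batch_1', 'batch_2']
  init_pool({0: fake_sequence})  # Populates `_SHARED_SEQUENCES`.
  assert get_index(0, 1) == 'batch_1'  # Worker-side lookup by (uid, index).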
@keras_export('keras.utils.SequenceEnqueuer')
class SequenceEnqueuer(object):
  """Base class to enqueue inputs.

  The task of an Enqueuer is to use parallelism to speed up preprocessing.
  This is done with processes or threads.

  Example:

  ```python
      enqueuer = SequenceEnqueuer(...)
      enqueuer.start()
      datas = enqueuer.get()
      for data in datas:
          # Use the inputs; training, evaluating, predicting.
          # ... stop sometime.
      enqueuer.stop()
  ```

  The `enqueuer.get()` should be an infinite stream of data.
  """

  def __init__(self, sequence,
               use_multiprocessing=False):
    self.sequence = sequence
    self.use_multiprocessing = use_multiprocessing

    global _SEQUENCE_COUNTER
    if _SEQUENCE_COUNTER is None:
      try:
        _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
      except OSError:
        # In this case the OS does not allow us to use
        # multiprocessing. We resort to an int
        # for enqueuer indexing.
        _SEQUENCE_COUNTER = 0

    if isinstance(_SEQUENCE_COUNTER, int):
      self.uid = _SEQUENCE_COUNTER
      _SEQUENCE_COUNTER += 1
    else:
      # Doing Multiprocessing.Value += x is not process-safe.
      with _SEQUENCE_COUNTER.get_lock():
        self.uid = _SEQUENCE_COUNTER.value
        _SEQUENCE_COUNTER.value += 1

    self.workers = 0
    self.executor_fn = None
    self.queue = None
    self.run_thread = None
    self.stop_signal = None

  def is_running(self):
    return self.stop_signal is not None and not self.stop_signal.is_set()

  def start(self, workers=1, max_queue_size=10):
    """Starts the handler's workers.

    Arguments:
        workers: Number of workers.
        max_queue_size: queue size
            (when full, workers could block on `put()`)
    """
    if self.use_multiprocessing:
      self.executor_fn = self._get_executor_init(workers)
    else:
      # We do not need the init since it's threads.
      self.executor_fn = lambda _: ThreadPool(workers)
    self.workers = workers
    self.queue = queue.Queue(max_queue_size)
    self.stop_signal = threading.Event()
    self.run_thread = threading.Thread(target=self._run)
    self.run_thread.daemon = True
    self.run_thread.start()

  def _send_sequence(self):
    """Sends current Iterable to all workers."""
    # For new processes that may spawn
    _SHARED_SEQUENCES[self.uid] = self.sequence

  def stop(self, timeout=None):
    """Stops running threads and waits for them to exit, if necessary.

    Should be called by the same thread which called `start()`.

    Arguments:
        timeout: maximum time to wait on `thread.join()`
    """
    self.stop_signal.set()
    with self.queue.mutex:
      self.queue.queue.clear()
      self.queue.unfinished_tasks = 0
      self.queue.not_full.notify()
    self.run_thread.join(timeout)
    _SHARED_SEQUENCES[self.uid] = None

  @abstractmethod
  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    raise NotImplementedError

  @abstractmethod
  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Arguments:
        workers: Number of workers.

    Returns:
        Function, a Function to initialize the pool
    """
    raise NotImplementedError

  @abstractmethod
  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Returns:
        Generator yielding tuples `(inputs, targets)`
            or `(inputs, targets, sample_weights)`.
    """
    raise NotImplementedError


@keras_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
  """Builds an Enqueuer from a Sequence.

  Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

  Arguments:
      sequence: A `tf.keras.utils.data_utils.Sequence` object.
      use_multiprocessing: use multiprocessing if True, otherwise threading
      shuffle: whether to shuffle the data at the beginning of each epoch
  """

  def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
    super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.shuffle = shuffle

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Arguments:
        workers: Number of workers.

    Returns:
        Function, a Function to initialize the pool
    """
    def pool_fn(seqs):
      return multiprocessing.Pool(
          workers, initializer=init_pool_generator, initargs=(seqs, None))

    return pool_fn

  def _wait_queue(self):
    """Wait for the queue to be empty."""
    while True:
      time.sleep(0.1)
      if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
        return

  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    sequence = list(range(len(self.sequence)))
    self._send_sequence()  # Share the initial sequence
    while True:
      if self.shuffle:
        random.shuffle(sequence)

      with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
        for i in sequence:
          if self.stop_signal.is_set():
            return
          self.queue.put(
              executor.apply_async(get_index, (self.uid, i)), block=True)

        # Done with the current epoch, waiting for the final batches
        self._wait_queue()

        if self.stop_signal.is_set():
          # We're done
          return

      # Call the internal on epoch end.
      self.sequence.on_epoch_end()
      self._send_sequence()  # Update the pool

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
        The next element in the queue, i.e. a tuple
        `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
    """
    try:
      while self.is_running():
        inputs = self.queue.get(block=True).get()
        self.queue.task_done()
        if inputs is not None:
          yield inputs
    except Exception:  # pylint: disable=broad-except
      self.stop()
      six.reraise(*sys.exc_info())


def init_pool_generator(gens, random_seed=None):
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = gens

  if random_seed is not None:
    ident = multiprocessing.current_process().ident
    np.random.seed(random_seed + ident)


def next_sample(uid):
  """Gets the next value from the generator `uid`.

  To allow multiple generators to be used at the same time, we use `uid` to
  get a specific one. A single generator would cause the validation to
  overwrite the training generator.

  Arguments:
      uid: int, generator identifier

  Returns:
      The next value of generator `uid`.
  """
  return six.next(_SHARED_SEQUENCES[uid])
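# Illustrative sketch (not part of the original module): a hedged example of
# driving an `OrderedEnqueuer` by hand, mirroring what `fit_generator` does
# internally. `my_sequence` is assumed to be any `Sequence` instance, for
# example a subclass like the hypothetical `_ShuffleSequence` sketch above.
def _demo_ordered_enqueuer(my_sequence, steps=10):
  """Consumes `steps` batches from `my_sequence` through an OrderedEnqueuer."""
  enqueuer = OrderedEnqueuer(my_sequence, use_multiprocessing=False)
  enqueuer.start(workers=2, max_queue_size=10)
  batches = enqueuer.get()  # Infinite generator over the Sequence, in order.
  try:
    for _ in range(steps):
      batch = next(batches)
      # ... train, evaluate or predict on `batch` here ...
      del batch
  finally:
    enqueuer.stop()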
@keras_export('keras.utils.GeneratorEnqueuer')
class GeneratorEnqueuer(SequenceEnqueuer):
  """Builds a queue out of a data generator.

  The provided generator can be finite, in which case the class will throw
  a `StopIteration` exception.

  Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

  Arguments:
      sequence: a generator which yields data
      use_multiprocessing: use multiprocessing if True, otherwise threading
      random_seed: Initial seed for workers,
          will be incremented by one for each worker.
  """

  def __init__(self, sequence,
               use_multiprocessing=False,
               random_seed=None):
    super(GeneratorEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.random_seed = random_seed

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Arguments:
        workers: Number of workers.

    Returns:
        A Function to initialize the pool
    """
    def pool_fn(seqs):
      return multiprocessing.Pool(workers,
                                  initializer=init_pool_generator,
                                  initargs=(seqs, self.random_seed))
    return pool_fn

  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    self._send_sequence()  # Share the initial generator
    with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
      while True:
        if self.stop_signal.is_set():
          return
        self.queue.put(
            executor.apply_async(next_sample, (self.uid,)), block=True)

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
        The next element in the queue, i.e. a tuple
        `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
    """
    try:
      while self.is_running():
        inputs = self.queue.get(block=True).get()
        self.queue.task_done()
        if inputs is not None:
          yield inputs
    except StopIteration:
      # Special case for finite generators
      last_ones = []
      while self.queue.qsize() > 0:
        last_ones.append(self.queue.get(block=True))
      # Wait for them to complete
      for f in last_ones:
        f.wait()
      # Keep the good ones
      last_ones = [future.get() for future in last_ones if future.successful()]
      for inputs in last_ones:
        if inputs is not None:
          yield inputs
    except Exception as e:  # pylint: disable=broad-except
      self.stop()
      if 'generator already executing' in str(e):
        raise RuntimeError(
            'Your generator is NOT thread-safe. '
            'Keras requires a thread-safe generator when '
            '`use_multiprocessing=False, workers > 1`. ')
      six.reraise(*sys.exc_info())
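# Illustrative sketch (not part of the original module): a hedged example of
# wrapping a plain Python generator in a `GeneratorEnqueuer`. As the error
# raised in `get()` above notes, a generator shared between several threads
# must be thread-safe; the single-worker setup below sidesteps that concern.
def _demo_generator_enqueuer(steps=10):
  """Consumes `steps` items from a toy generator through a GeneratorEnqueuer."""

  def toy_generator():
    i = 0
    while True:
      yield np.arange(i, i + 4), np.ones(4)  # Hypothetical (inputs, targets).
      i += 4

  enqueuer = GeneratorEnqueuer(toy_generator(), use_multiprocessing=False)
  enqueuer.start(workers=1, max_queue_size=10)
  batches = enqueuer.get()
  try:
    for _ in range(steps):
      inputs, targets = next(batches)
      # ... use `inputs` and `targets` here ...
      del inputs, targets
  finally:
    enqueuer.stop()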