• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-import-not-at-top
16"""Utilities for file download and caching."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from abc import abstractmethod
22from contextlib import closing
23import hashlib
24import multiprocessing
25from multiprocessing.pool import ThreadPool
26import os
27import random
28import shutil
29import sys
30import tarfile
31import threading
32import time
33import zipfile
34
35import numpy as np
36import six
37from six.moves.urllib.error import HTTPError
38from six.moves.urllib.error import URLError
39from six.moves.urllib.request import urlopen
40
41from tensorflow.python.keras.utils.generic_utils import Progbar
42from tensorflow.python.util import tf_inspect
43from tensorflow.python.util.tf_export import keras_export
44
45
46try:
47  import queue
48except ImportError:
49  import Queue as queue
50
51
if sys.version_info[0] == 2:

  def urlretrieve(url, filename, reporthook=None, data=None):
    """Replacement for `urlretrieve` for Python 2.

    Under Python 2, `urlretrieve` relies on `FancyURLopener` from legacy
    `urllib` module, known to have issues with proxy management.

    Arguments:
        url: url to retrieve.
        filename: where to store the retrieved data locally.
        reporthook: a hook function that will be called once
            on establishment of the network connection and once
            after each block read thereafter.
            The hook will be passed three arguments;
            a count of blocks transferred so far,
            a block size in bytes, and the total size of the file
            (-1 when the response has no `Content-Length` header).
        data: `data` argument passed to `urlopen`.
    """

    def chunk_read(response, chunk_size=8192, reporthook=None):
      """Yields successive non-empty chunks of `response`, reporting progress."""
      content_length = response.info().get('Content-Length')
      total_size = -1
      if content_length is not None:
        total_size = int(content_length.strip())
      count = 0
      # Initial report on connection establishment (count == 0), matching
      # the documented contract and the Python 3 `urlretrieve`.
      if reporthook is not None:
        reporthook(count, chunk_size, total_size)
      while True:
        chunk = response.read(chunk_size)
        if not chunk:
          # EOF: do not report a phantom extra block.
          break
        count += 1
        if reporthook is not None:
          reporthook(count, chunk_size, total_size)
        yield chunk

    # `closing` guarantees the network response is released even if writing
    # to `filename` fails part-way through.
    with closing(urlopen(url, data)) as response:
      with open(filename, 'wb') as fd:
        for chunk in chunk_read(response, reporthook=reporthook):
          fd.write(chunk)
else:
  from six.moves.urllib.request import urlretrieve
94
95
def is_generator_or_sequence(x):
  """Returns whether `x` is a generator or a Keras `Sequence` instance.

  Arguments:
      x: any object.

  Returns:
      True if `tf_inspect.isgenerator(x)` holds or `x` is an instance of
      `Sequence`, False otherwise.
  """
  if tf_inspect.isgenerator(x):
    return True
  return isinstance(x, Sequence)
99
100
def _extract_archive(file_path, path='.', archive_format='auto'):
  """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.

  Arguments:
      file_path: path to the archive file
      path: path to extract the archive file
      archive_format: Archive format to try for extracting the file.
          Options are 'auto', 'tar', 'zip', and None.
          'tar' includes tar, tar.gz, and tar.bz files.
          The default 'auto' is ['tar', 'zip'].
          None or an empty list will return no matches found.

  Returns:
      True if a match was found and an archive extraction was completed,
      False otherwise.

  Raises:
      ValueError: if `archive_format` contains a format other than 'tar' or
          'zip' (raised when that entry is reached).
  """
  if archive_format is None:
    return False
  if archive_format == 'auto':
    archive_format = ['tar', 'zip']
  elif isinstance(archive_format, six.string_types):
    archive_format = [archive_format]

  # Maps each supported format to (opener, detector).
  handlers = {
      'tar': (tarfile.open, tarfile.is_tarfile),
      'zip': (zipfile.ZipFile, zipfile.is_zipfile),
  }

  for archive_type in archive_format:
    if archive_type not in handlers:
      # Without this check an unknown name either hit an unbound
      # `is_match_fn` or silently reused the preceding format's functions.
      raise ValueError('Unsupported archive format: %r. Supported formats '
                       "are 'tar' and 'zip'." % (archive_type,))
    open_fn, is_match_fn = handlers[archive_type]
    if not is_match_fn(file_path):
      continue

    # NOTE(review): the archive may come from an arbitrary download origin;
    # on older Python versions tarfile's `extractall` does not reject member
    # names containing '..' or absolute paths. Consider validating members.
    with open_fn(file_path) as archive:
      try:
        archive.extractall(path)
      except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
        # Best-effort cleanup of a partial extraction. Note that this removes
        # `path` itself, which may also contain unrelated files.
        if os.path.exists(path):
          if os.path.isfile(path):
            os.remove(path)
          else:
            shutil.rmtree(path)
        raise
    return True
  return False
145
146
@keras_export('keras.utils.get_file')
def get_file(fname,
             origin,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
  """Downloads a file from a URL if it is not already in the cache.

  By default the file at the url `origin` is downloaded to the
  cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
  and given the filename `fname`. The final location of a file
  `example.txt` would therefore be `~/.keras/datasets/example.txt`.
  If the cache dir is not writable, `/tmp/.keras` is used instead.

  Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
  Passing a hash will verify the file after download. The command line
  programs `shasum` and `sha256sum` can compute the hash.

  Arguments:
      fname: Name of the file. If an absolute path `/path/to/file.txt` is
          specified the file will be saved at that location.
      origin: Original URL of the file.
      untar: Deprecated in favor of 'extract'.
          boolean, whether the file should be decompressed
      md5_hash: Deprecated in favor of 'file_hash'.
          md5 hash of the file for verification
      file_hash: The expected hash string of the file after download.
          The sha256 and md5 hash algorithms are both supported.
      cache_subdir: Subdirectory under the Keras cache dir where the file is
          saved. If an absolute path `/path/to/folder` is
          specified the file will be saved at that location.
      hash_algorithm: Select the hash algorithm to verify the file.
          options are 'md5', 'sha256', and 'auto'.
          The default 'auto' detects the hash algorithm in use.
      extract: True tries extracting the file as an Archive, like tar or zip.
      archive_format: Archive format to try for extracting the file.
          Options are 'auto', 'tar', 'zip', and None.
          'tar' includes tar, tar.gz, and tar.bz files.
          The default 'auto' is ['tar', 'zip'].
          None or an empty list will return no matches found.
      cache_dir: Location to store cached files, when None it
          defaults to the [Keras
            Directory](/faq/#where-is-the-keras-configuration-filed-stored).

  Returns:
      Path to the downloaded file

  Raises:
      ValueError: if `file_hash` is given and a freshly downloaded file does
          not match it.
      Exception: if fetching `origin` fails with an HTTP or URL error.
  """
  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'
  datadir_base = os.path.expanduser(cache_dir)
  if not os.access(datadir_base, os.W_OK):
    # Not writable (or not existing yet): fall back to a temp location.
    datadir_base = os.path.join('/tmp', '.keras')
  datadir = os.path.join(datadir_base, cache_subdir)
  if not os.path.exists(datadir):
    os.makedirs(datadir)

  if untar:
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  download = False
  if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if file_hash is not None:
      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
        print('A local file was found, but it seems to be '
              'incomplete or outdated because the ' + hash_algorithm +
              ' file hash does not match the original value of ' + file_hash +
              ' so we will re-download the data.')
        download = True
  else:
    download = True

  if download:
    print('Downloading data from', origin)

    class ProgressTracker(object):
      # Maintain progbar for the lifetime of download.
      # This design was chosen for Python 2.7 compatibility.
      progbar = None

    def dl_progress(count, block_size, total_size):
      # First call creates the bar (unknown size -> indeterminate bar);
      # later calls advance it to the number of bytes fetched so far.
      if ProgressTracker.progbar is None:
        if total_size == -1:
          total_size = None
        ProgressTracker.progbar = Progbar(total_size)
      else:
        ProgressTracker.progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        urlretrieve(origin, fpath, dl_progress)
      except HTTPError as e:
        raise Exception(error_msg.format(origin, e.code, e.msg))
      except URLError as e:
        raise Exception(error_msg.format(origin, e.errno, e.reason))
    except (Exception, KeyboardInterrupt):
      # Never leave a partial download behind in the cache.
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    ProgressTracker.progbar = None

    # Honor the documented contract: a provided hash is checked after a
    # fresh download too, not only against an already-cached copy.
    if file_hash is not None and not validate_file(
        fpath, file_hash, algorithm=hash_algorithm):
      raise ValueError('Incomplete or corrupted file detected. The ' +
                       hash_algorithm + ' file hash does not match the '
                       'provided value of ' + file_hash + '.')

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  return fpath
268
269
def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
  """Calculates a file sha256 or md5 hash.

  Example:

  ```python
      >>> from tensorflow.python.keras.utils.data_utils import _hash_file
      >>> _hash_file('/path/to/empty_file')
      'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
  ```

  Arguments:
      fpath: path to the file being validated
      algorithm: hash algorithm, one of 'auto', 'sha256', or 'md5'.
          There is no reference hash to inspect here, so 'auto' uses
          sha256 (the default); any other value selects md5.
      chunk_size: Bytes to read at a time, important for large files.

  Returns:
      The file hash, as a hexadecimal string.
  """
  # 'auto' cannot detect anything without an expected digest; the previous
  # `len(hash) == 64` test measured the builtin `hash` and raised TypeError.
  if algorithm in ('sha256', 'auto'):
    hasher = hashlib.sha256()
  else:
    hasher = hashlib.md5()

  with open(fpath, 'rb') as fpath_file:
    for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
      hasher.update(chunk)

  return hasher.hexdigest()
300
301
def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
  """Validates a file against a sha256 or md5 hash.

  Arguments:
      fpath: path to the file being validated
      file_hash:  The expected hash string of the file.
          The sha256 and md5 hash algorithms are both supported.
      algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
          'auto' picks sha256 when `file_hash` is 64 characters long and
          md5 otherwise; any value other than 'sha256'/'auto' means md5.
      chunk_size: Bytes to read at a time, important for large files.

  Returns:
      Whether the file is valid
  """
  use_sha256 = algorithm == 'sha256' or (
      algorithm == 'auto' and len(file_hash) == 64)
  hasher = 'sha256' if use_sha256 else 'md5'
  actual_hash = _hash_file(fpath, hasher, chunk_size)
  return str(actual_hash) == str(file_hash)
325
326
@keras_export('keras.utils.Sequence')
class Sequence(object):
  """Base object for fitting to a sequence of data, such as a dataset.

  Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
  If you want to modify your dataset between epochs you may implement
  `on_epoch_end`. The method `__getitem__` should return a complete batch.

  Notes:

  `Sequence` is a safer way to do multiprocessing. This structure guarantees
  that the network will only train once on each sample per epoch, which is
  not the case with generators.

  Examples:

  ```python
      from skimage.io import imread
      from skimage.transform import resize
      import numpy as np
      import math

      # Here, `x_set` is list of path to the images
      # and `y_set` are the associated classes.

      class CIFAR10Sequence(Sequence):

          def __init__(self, x_set, y_set, batch_size):
              self.x, self.y = x_set, y_set
              self.batch_size = batch_size

          def __len__(self):
              return math.ceil(len(self.x) / self.batch_size)

          def __getitem__(self, idx):
              start = idx * self.batch_size
              end = (idx + 1) * self.batch_size
              batch_x = self.x[start:end]
              batch_y = self.y[start:end]

              return np.array([
                  resize(imread(file_name), (200, 200))
                  for file_name in batch_x]), np.array(batch_y)
  ```
  """

  @abstractmethod
  def __getitem__(self, index):
    """Gets batch at position `index`.

    Arguments:
        index: position of the batch in the Sequence.

    Returns:
        A batch
    """
    raise NotImplementedError

  @abstractmethod
  def __len__(self):
    """Number of batches in the Sequence.

    Returns:
        The number of batches in the Sequence.
    """
    raise NotImplementedError

  def on_epoch_end(self):
    """Method called at the end of every epoch.

    Does nothing by default; override it to modify the dataset between
    epochs.
    """

  def __iter__(self):
    """Creates a generator that yields every batch of the Sequence in order."""
    for position in range(len(self)):
      yield self[position]
404
405
def iter_sequence_infinite(seq):
  """Iterates indefinitely over a Sequence.

  Arguments:
    seq: Sequence instance.

  Yields:
    Batches of data from the Sequence.

  Raises:
    ValueError: if a complete pass over `seq` produces no batch. Without this
      check an empty sequence makes the generator loop forever without ever
      yielding.
  """
  while True:
    yielded_any = False
    for item in seq:
      yielded_any = True
      yield item
    if not yielded_any:
      raise ValueError('Cannot iterate indefinitely over an empty sequence.')
418
419
# Module-level state shared with pool workers: maps an enqueuer's `uid` to
# the object it feeds from. Pool initializers (`init_pool`,
# `init_pool_generator`) rebind it wholesale inside each worker.
_SHARED_SEQUENCES = {}
# Source of unique enqueuer ids; created lazily by `SequenceEnqueuer` as a
# `multiprocessing.Value` (or a plain int when the OS forbids that).
_SEQUENCE_COUNTER = None
424
425
def init_pool(seqs):
  """Pool initializer that installs `seqs` as this worker's shared state.

  Arguments:
      seqs: mapping from enqueuer uid to the object it reads from (the same
          structure as `_SHARED_SEQUENCES`).
  """
  global _SHARED_SEQUENCES  # rebind the module-level name, not a local
  _SHARED_SEQUENCES = seqs
429
430
def get_index(uid, i):
  """Returns item `i` of the shared Sequence registered under `uid`.

  Several Sequences may be in flight at once (e.g. training and validation),
  so each one is looked up by its enqueuer's `uid`; a single shared slot
  would let the validation Sequence overwrite the training one.

  Arguments:
      uid: int, Sequence identifier
      i: index

  Returns:
      The value at index `i`.
  """
  shared_sequence = _SHARED_SEQUENCES[uid]
  return shared_sequence[i]
446
447
@keras_export('keras.utils.SequenceEnqueuer')
class SequenceEnqueuer(object):
  """Base class to enqueue inputs.

  The task of an Enqueuer is to use parallelism to speed up preprocessing.
  This is done with processes or threads.

  Example:

  ```python
      enqueuer = SequenceEnqueuer(...)
      enqueuer.start()
      datas = enqueuer.get()
      for data in datas:
          # Use the inputs; training, evaluating, predicting.
          # ... stop sometime.
      enqueuer.stop()
  ```

  The `enqueuer.get()` should be an infinite stream of datas.

  Subclasses implement `_run` (the producer loop that `start()` runs on a
  daemon thread), `_get_executor_init` and `get`.
  """

  def __init__(self, sequence,
               use_multiprocessing=False):
    """Stores `sequence` and assigns this enqueuer a unique `uid`.

    Arguments:
        sequence: the object to enqueue from; published to workers as
            `_SHARED_SEQUENCES[self.uid]`.
        use_multiprocessing: use a process pool if True, otherwise a thread
            pool.
    """
    self.sequence = sequence
    self.use_multiprocessing = use_multiprocessing

    global _SEQUENCE_COUNTER
    if _SEQUENCE_COUNTER is None:
      # Lazily create the module-level id counter on first use.
      try:
        _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
      except OSError:
        # In this case the OS does not allow us to use
        # multiprocessing. We resort to an int
        # for enqueuer indexing.
        _SEQUENCE_COUNTER = 0

    if isinstance(_SEQUENCE_COUNTER, int):
      self.uid = _SEQUENCE_COUNTER
      _SEQUENCE_COUNTER += 1
    else:
      # Doing Multiprocessing.Value += x is not process-safe.
      with _SEQUENCE_COUNTER.get_lock():
        self.uid = _SEQUENCE_COUNTER.value
        _SEQUENCE_COUNTER.value += 1

    # Run-time state; populated by `start()`.
    self.workers = 0
    self.executor_fn = None
    self.queue = None
    self.run_thread = None
    self.stop_signal = None

  def is_running(self):
    """Returns True after `start()` has been called and before `stop()`."""
    return self.stop_signal is not None and not self.stop_signal.is_set()

  def start(self, workers=1, max_queue_size=10):
    """Starts the handler's workers.

    Arguments:
        workers: Number of workers.
        max_queue_size: queue size
            (when full, workers could block on `put()`)
    """
    # `executor_fn` builds the worker pool; the thread variant ignores the
    # argument it is called with.
    if self.use_multiprocessing:
      self.executor_fn = self._get_executor_init(workers)
    else:
      # We do not need the init since it's threads.
      self.executor_fn = lambda _: ThreadPool(workers)
    self.workers = workers
    self.queue = queue.Queue(max_queue_size)
    self.stop_signal = threading.Event()
    # Daemon thread so a forgotten enqueuer does not keep the interpreter
    # alive at exit.
    self.run_thread = threading.Thread(target=self._run)
    self.run_thread.daemon = True
    self.run_thread.start()

  def _send_sequence(self):
    """Sends current Iterable to all workers."""
    # For new processes that may spawn
    _SHARED_SEQUENCES[self.uid] = self.sequence

  def stop(self, timeout=None):
    """Stops running threads and wait for them to exit, if necessary.

    Should be called by the same thread which called `start()`.

    NOTE(review): calling this before `start()` raises AttributeError, since
    `stop_signal` and `queue` are still None.

    Arguments:
        timeout: maximum time to wait on `thread.join()`
    """
    self.stop_signal.set()
    # Drop pending items under the queue's own lock, reset its task counter
    # (the dropped items will never be marked done) and wake a producer
    # blocked in `put()` so it can observe the stop signal.
    with self.queue.mutex:
      self.queue.queue.clear()
      self.queue.unfinished_tasks = 0
      self.queue.not_full.notify()
    self.run_thread.join(timeout)
    # Release the reference to the sequence; the `uid` key itself remains.
    _SHARED_SEQUENCES[self.uid] = None

  @abstractmethod
  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    raise NotImplementedError

  @abstractmethod
  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Arguments:
        workers: Number of workers.

    Returns:
        Function, a Function to initialize the pool
    """
    raise NotImplementedError

  @abstractmethod
  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Returns:
        Generator yielding tuples `(inputs, targets)`
            or `(inputs, targets, sample_weights)`.
    """
    raise NotImplementedError
571
572
@keras_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
  """Builds an Enqueuer from a Sequence.

  Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

  Arguments:
      sequence: A `tf.keras.utils.data_utils.Sequence` object.
      use_multiprocessing: use multiprocessing if True, otherwise threading
      shuffle: whether to shuffle the data at the beginning of each epoch
  """

  def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
    super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.shuffle = shuffle

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Arguments:
        workers: Number of workers.

    Returns:
        Function, a Function to initialize the pool
    """
    def pool_fn(seqs):
      return multiprocessing.Pool(
          workers, initializer=init_pool_generator, initargs=(seqs, None))

    return pool_fn

  def _wait_queue(self):
    """Wait for the queue to be empty."""
    # Polls until every queued future was acknowledged with `task_done()`
    # or the enqueuer was stopped.
    while True:
      time.sleep(0.1)
      if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
        return

  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    # NOTE(review): the index list is built once, so a Sequence whose length
    # changes in `on_epoch_end` keeps being iterated with the old length.
    sequence = list(range(len(self.sequence)))
    self._send_sequence()  # Share the initial sequence
    while True:
      if self.shuffle:
        random.shuffle(sequence)

      # A new executor is built every epoch, after `_send_sequence()` has
      # refreshed `_SHARED_SEQUENCES`.
      with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
        for i in sequence:
          if self.stop_signal.is_set():
            return
          self.queue.put(
              executor.apply_async(get_index, (self.uid, i)), block=True)

        # Done with the current epoch, waiting for the final batches
        self._wait_queue()

        if self.stop_signal.is_set():
          # We're done
          return

      # Call the internal on epoch end.
      self.sequence.on_epoch_end()
      self._send_sequence()  # Update the pool

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
        The next element in the queue, i.e. a tuple
        `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
    """
    try:
      while self.is_running():
        try:
          # Wake up periodically: once `stop()` has cleared the queue nothing
          # more is put, so an unbounded wait here would hang forever.
          future = self.queue.get(block=True, timeout=5)
        except queue.Empty:
          continue
        inputs = future.get()
        # `stop()` resets the queue's task counter to 0; acknowledging a task
        # after that would raise ValueError.
        if self.is_running():
          self.queue.task_done()
        if inputs is not None:
          yield inputs
    except Exception:  # pylint: disable=broad-except
      self.stop()
      six.reraise(*sys.exc_info())
656
657
def init_pool_generator(gens, random_seed=None):
  """Pool initializer for enqueuer workers.

  Installs `gens` as this worker's shared state. When `random_seed` is given,
  it also seeds NumPy's global RNG with `random_seed` plus the current
  process' `ident`, so that different worker processes draw different
  streams.

  Arguments:
      gens: mapping from enqueuer uid to the Sequence or generator it reads
          from (the same structure as `_SHARED_SEQUENCES`).
      random_seed: optional int base seed; `None` leaves the RNG untouched.
  """
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = gens

  if random_seed is not None:
    ident = multiprocessing.current_process().ident
    # `np.random.seed` only accepts values in [0, 2**32). Wrap so that large
    # or negative base seeds do not crash the worker; sums already in range
    # are unchanged.
    np.random.seed((random_seed + ident) % (2 ** 32))
665
666
def next_sample(uid):
  """Gets the next value from the generator `uid`.

  To allow multiple generators to be used at the same time, we use `uid` to
  get a specific one. A single generator would cause the validation to
  overwrite the training generator.

  Arguments:
      uid: int, generator identifier

  Returns:
      The next value of generator `uid`.

  Raises:
      StopIteration: when that generator is exhausted.
  """
  # `six.next` is merely an alias of the builtin `next` (Python 2.6+).
  return next(_SHARED_SEQUENCES[uid])
681
682
@keras_export('keras.utils.GeneratorEnqueuer')
class GeneratorEnqueuer(SequenceEnqueuer):
  """Builds a queue out of a data generator.

  The provided generator can be finite: once it raises `StopIteration`,
  `get()` yields the batches that were already queued and then ends.

  Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

  Arguments:
      sequence: a generator (iterator) which yields data
      use_multiprocessing: use multiprocessing if True, otherwise threading
      random_seed: Initial seed for workers; with multiprocessing each worker
          process seeds NumPy with this value plus its process ident (see
          `init_pool_generator`).
  """

  def __init__(self, sequence,
               use_multiprocessing=False,
               random_seed=None):
    super(GeneratorEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.random_seed = random_seed

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Arguments:
      workers: Number of workers.

    Returns:
        A Function to initialize the pool
    """
    def pool_fn(seqs):
      return multiprocessing.Pool(workers,
                                  initializer=init_pool_generator,
                                  initargs=(seqs, self.random_seed))
    return pool_fn

  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    self._send_sequence()  # Share the initial generator
    with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
      while True:
        if self.stop_signal.is_set():
          return
        self.queue.put(
            executor.apply_async(next_sample, (self.uid,)), block=True)

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
        The next element in the queue, i.e. a tuple
        `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
    """
    try:
      while self.is_running():
        try:
          # Wake up periodically: once `stop()` has cleared the queue nothing
          # more is put, so an unbounded wait here would hang forever.
          future = self.queue.get(block=True, timeout=5)
        except queue.Empty:
          continue
        inputs = future.get()
        # `stop()` resets the queue's task counter to 0; acknowledging a task
        # after that would raise ValueError.
        if self.is_running():
          self.queue.task_done()
        if inputs is not None:
          yield inputs
    except StopIteration:
      # Special case for finite generators: drain the futures already queued.
      last_ones = []
      while self.queue.qsize() > 0:
        last_ones.append(self.queue.get(block=True))
      # Wait for them to complete
      for f in last_ones:
        f.wait()
      # Keep the good ones
      last_ones = [done.get() for done in last_ones if done.successful()]
      for inputs in last_ones:
        if inputs is not None:
          yield inputs
    except Exception as e:  # pylint: disable=broad-except
      self.stop()
      if 'generator already executing' in str(e):
        raise RuntimeError(
            'Your generator is NOT thread-safe. '
            'Keras requires a thread-safe generator when '
            '`use_multiprocessing=False, workers > 1`. ')
      six.reraise(*sys.exc_info())
768