# Lint as: python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
"""Utilities for file download and caching."""

from abc import abstractmethod
from contextlib import closing
import functools
import hashlib
import multiprocessing
import multiprocessing.dummy
import os
import queue
import random
import shutil
import sys  # pylint: disable=unused-import
import tarfile
import threading
import time
import typing
import urllib
import weakref
import zipfile

import numpy as np

from tensorflow.python.framework import ops
from six.moves.urllib.request import urlopen
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.util.tf_export import keras_export

# Required to support google internal urlretrieve
if sys.version_info[0] == 2:

  def urlretrieve(url, filename, reporthook=None, data=None):
    """Replacement for `urlretrieve` for Python 2.

    Under Python 2, `urlretrieve` relies on `FancyURLopener` from legacy
    `urllib` module, known to have issues with proxy management.

    Args:
        url: url to retrieve.
        filename: where to store the retrieved data locally.
        reporthook: a hook function that will be called once on establishment of
          the network connection and once after each block read thereafter. The
          hook will be passed three arguments: a count of blocks transferred so
          far, a block size in bytes, and the total size of the file.
        data: `data` argument passed to `urlopen`.
    """

    def chunk_read(response, chunk_size=8192, reporthook=None):
      content_type = response.info().get('Content-Length')
      total_size = -1
      if content_type is not None:
        total_size = int(content_type.strip())
      count = 0
      while True:
        chunk = response.read(chunk_size)
        count += 1
        if reporthook is not None:
          reporthook(count, chunk_size, total_size)
        if chunk:
          yield chunk
        else:
          break

    response = urlopen(url, data)
    with open(filename, 'wb') as fd:
      for chunk in chunk_read(response, reporthook=reporthook):
        fd.write(chunk)
else:
  from urllib.request import urlretrieve  # pylint: disable=g-importing-member


def is_generator_or_sequence(x):
  """Check if `x` is a Keras generator type."""
  builtin_iterators = (str, list, tuple, dict, set, frozenset)
  if isinstance(x, (ops.Tensor, np.ndarray) + builtin_iterators):
    return False
  return (tf_inspect.isgenerator(x) or
          isinstance(x, Sequence) or
          isinstance(x, typing.Iterator))


def _extract_archive(file_path, path='.', archive_format='auto'):
  """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.

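  Example (an illustrative sketch; the archive path below is hypothetical):

  ```python
  # Unpacks 'data.tar.gz' into the 'data' directory; returns True if the file
  # was recognized as a tar or zip archive and extraction completed.
  _extract_archive('data.tar.gz', 'data')
  ```
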
  Args:
      file_path: path to the archive file
      path: path to extract the archive file
      archive_format: Archive format to try for extracting the file.
          Options are 'auto', 'tar', 'zip', and None.
          'tar' includes tar, tar.gz, and tar.bz files.
          The default 'auto' is ['tar', 'zip'].
          None or an empty list will return no matches found.

  Returns:
      True if a match was found and an archive extraction was completed,
      False otherwise.
  """
  if archive_format is None:
    return False
  if archive_format == 'auto':
    archive_format = ['tar', 'zip']
  if isinstance(archive_format, str):
    archive_format = [archive_format]

  file_path = path_to_string(file_path)
  path = path_to_string(path)

  for archive_type in archive_format:
    if archive_type == 'tar':
      open_fn = tarfile.open
      is_match_fn = tarfile.is_tarfile
    if archive_type == 'zip':
      open_fn = zipfile.ZipFile
      is_match_fn = zipfile.is_zipfile

    if is_match_fn(file_path):
      with open_fn(file_path) as archive:
        try:
          archive.extractall(path)
        except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
          if os.path.exists(path):
            if os.path.isfile(path):
              os.remove(path)
            else:
              shutil.rmtree(path)
          raise
      return True
  return False


@keras_export('keras.utils.get_file')
def get_file(fname,
             origin,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
  """Downloads a file from a URL if it is not already in the cache.

  By default the file at the url `origin` is downloaded to the
  cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
  and given the filename `fname`. The final location of a file
  `example.txt` would therefore be `~/.keras/datasets/example.txt`.

  Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
  Passing a hash will verify the file after download. The command line
  programs `shasum` and `sha256sum` can compute the hash.

  Example:

  ```python
  path_to_downloaded_file = tf.keras.utils.get_file(
      "flower_photos",
      "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
      untar=True)
  ```
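
  Hash verification (illustrative only; the URL and hash below are
  placeholders, not real values):

  ```python
  path = tf.keras.utils.get_file(
      "mydata.tar.gz",
      "https://example.com/mydata.tar.gz",          # hypothetical URL
      file_hash="<sha256 hex digest of the file>",  # e.g. from `sha256sum`
      extract=True)
  ```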

  Args:
      fname: Name of the file. If an absolute path `/path/to/file.txt` is
          specified the file will be saved at that location.
      origin: Original URL of the file.
      untar: Deprecated in favor of `extract` argument.
          boolean, whether the file should be decompressed.
      md5_hash: Deprecated in favor of `file_hash` argument.
          md5 hash of the file for verification.
      file_hash: The expected hash string of the file after download.
          The sha256 and md5 hash algorithms are both supported.
      cache_subdir: Subdirectory under the Keras cache dir where the file is
          saved. If an absolute path `/path/to/folder` is
          specified the file will be saved at that location.
      hash_algorithm: Select the hash algorithm to verify the file.
          Options are `'md5'`, `'sha256'`, and `'auto'`.
          The default 'auto' detects the hash algorithm in use.
      extract: True tries extracting the file as an Archive, like tar or zip.
      archive_format: Archive format to try for extracting the file.
          Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
          `'tar'` includes tar, tar.gz, and tar.bz files.
          The default `'auto'` corresponds to `['tar', 'zip']`.
          None or an empty list will return no matches found.
      cache_dir: Location to store cached files; when None, it
          defaults to the default directory `~/.keras/`.

  Returns:
      Path to the downloaded file.
  """
  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'
  datadir_base = os.path.expanduser(cache_dir)
  if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join('/tmp', '.keras')
  datadir = os.path.join(datadir_base, cache_subdir)
  _makedirs_exist_ok(datadir)

  fname = path_to_string(fname)

  if untar:
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  download = False
  if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if file_hash is not None:
      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
        print('A local file was found, but it seems to be '
              'incomplete or outdated because the ' + hash_algorithm +
              ' file hash does not match the original value of ' + file_hash +
              ' so we will re-download the data.')
        download = True
  else:
    download = True

  if download:
    print('Downloading data from', origin)

    class ProgressTracker(object):
      # Maintain progbar for the lifetime of download.
      # This design was chosen for Python 2.7 compatibility.
      progbar = None

    def dl_progress(count, block_size, total_size):
      if ProgressTracker.progbar is None:
        if total_size == -1:
          total_size = None
        ProgressTracker.progbar = Progbar(total_size)
      else:
        ProgressTracker.progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        urlretrieve(origin, fpath, dl_progress)
      except urllib.error.HTTPError as e:
        raise Exception(error_msg.format(origin, e.code, e.msg))
      except urllib.error.URLError as e:
        raise Exception(error_msg.format(origin, e.errno, e.reason))
    except (Exception, KeyboardInterrupt) as e:
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    ProgressTracker.progbar = None

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  return fpath


def _makedirs_exist_ok(datadir):
  os.makedirs(datadir, exist_ok=True)  # pylint: disable=unexpected-keyword-arg


def _resolve_hasher(algorithm, file_hash=None):
  """Returns hash algorithm as hashlib function."""
  if algorithm == 'sha256':
    return hashlib.sha256()

  if algorithm == 'auto' and file_hash is not None and len(file_hash) == 64:
    return hashlib.sha256()

  # This is used only for legacy purposes.
  return hashlib.md5()


def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
  """Calculates a file sha256 or md5 hash.

  Example:

  ```python
  _hash_file('/path/to/file.zip')
  'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
  ```

  Args:
      fpath: path to the file being validated
      algorithm: hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`.
          The default is `'sha256'`.
      chunk_size: Bytes to read at a time, important for large files.

  Returns:
      The file hash
  """
  if isinstance(algorithm, str):
    hasher = _resolve_hasher(algorithm)
  else:
    hasher = algorithm

  with open(fpath, 'rb') as fpath_file:
    for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
      hasher.update(chunk)

  return hasher.hexdigest()


def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
  """Validates a file against a sha256 or md5 hash.

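  Example (an illustrative sketch; `fpath` and `expected_hash` are
  placeholders rather than real values):

  ```python
  # `expected_hash` would be the hex digest reported by e.g. `sha256sum`.
  validate_file(fpath, expected_hash, algorithm='sha256')
  ```
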
  Args:
      fpath: path to the file being validated
      file_hash: The expected hash string of the file.
          The sha256 and md5 hash algorithms are both supported.
      algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
          The default 'auto' detects the hash algorithm in use.
      chunk_size: Bytes to read at a time, important for large files.

  Returns:
      Whether the file is valid
  """
  hasher = _resolve_hasher(algorithm, file_hash)

  return str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash)


class ThreadsafeIter(object):
  """Wrap an iterator with a lock and propagate exceptions to all threads."""

  def __init__(self, it):
    self.it = it
    self.lock = threading.Lock()

    # After a generator throws an exception all subsequent next() calls raise a
    # StopIteration Exception. This, however, presents an issue when mixing
    # generators and threading because it means the order of retrieval need not
    # match the order in which the generator was called. This can make it appear
    # that a generator exited normally when in fact the terminating exception is
    # just in a different thread. In order to provide thread safety, once
    # self.it has thrown an exception we continue to throw the same exception.
    self._exception = None

  def __iter__(self):
    return self

  def next(self):
    return self.__next__()

  def __next__(self):
    with self.lock:
      if self._exception:
        raise self._exception  # pylint: disable=raising-bad-type

      try:
        return next(self.it)
      except Exception as e:
        self._exception = e
        raise


def threadsafe_generator(f):
  """Decorator wrapping the iterator returned by `f` in a `ThreadsafeIter`."""

  @functools.wraps(f)
  def g(*a, **kw):
    return ThreadsafeIter(f(*a, **kw))

  return g


@keras_export('keras.utils.Sequence')
class Sequence(object):
  """Base object for fitting to a sequence of data, such as a dataset.

  Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
  If you want to modify your dataset between epochs you may implement
  `on_epoch_end`.
  The method `__getitem__` should return a complete batch.

  Notes:

  `Sequence` is a safer way to do multiprocessing. This structure guarantees
  that the network will only train once on each sample per epoch, which is not
  the case with generators.

  Examples:

  ```python
  from skimage.io import imread
  from skimage.transform import resize
  import numpy as np
  import math

  # Here, `x_set` is a list of paths to the images
  # and `y_set` are the associated classes.

  class CIFAR10Sequence(Sequence):

      def __init__(self, x_set, y_set, batch_size):
          self.x, self.y = x_set, y_set
          self.batch_size = batch_size

      def __len__(self):
          return math.ceil(len(self.x) / self.batch_size)

      def __getitem__(self, idx):
          batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
          batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

          return np.array([
              resize(imread(file_name), (200, 200))
                 for file_name in batch_x]), np.array(batch_y)
  ```
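
  A `Sequence` such as the one above can be passed directly to `Model.fit`
  (an illustrative sketch; `model`, `x_set` and `y_set` are assumed to exist):

  ```python
  model.fit(CIFAR10Sequence(x_set, y_set, batch_size=32), epochs=10)
  ```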
  """

  @abstractmethod
  def __getitem__(self, index):
    """Gets batch at position `index`.

    Args:
        index: position of the batch in the Sequence.

    Returns:
        A batch
    """
    raise NotImplementedError

  @abstractmethod
  def __len__(self):
    """Number of batches in the Sequence.

    Returns:
        The number of batches in the Sequence.
    """
    raise NotImplementedError

  def on_epoch_end(self):
    """Method called at the end of every epoch."""
    pass

  def __iter__(self):
    """Creates a generator that iterates over the Sequence."""
    for item in (self[i] for i in range(len(self))):
      yield item


def iter_sequence_infinite(seq):
  """Iterates indefinitely over a Sequence.

  Args:
    seq: `Sequence` instance.

  Yields:
    Batches of data from the `Sequence`.
  """
  while True:
    for item in seq:
      yield item


# Global variables to be shared across processes
_SHARED_SEQUENCES = {}
# We use a Value to provide a unique id to different processes.
_SEQUENCE_COUNTER = None


# Because multiprocessing pools are inherently unsafe, starting from a clean
# state can be essential to avoiding deadlocks. In order to accomplish this, we
# need to be able to check on the status of Pools that we create.
_DATA_POOLS = weakref.WeakSet()
_WORKER_ID_QUEUE = None  # Only created if needed.
_WORKER_IDS = set()
_FORCE_THREADPOOL = False
_FORCE_THREADPOOL_LOCK = threading.RLock()


def dont_use_multiprocessing_pool(f):
  """Decorator forcing a thread pool (not a process pool) while `f` runs."""
  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    with _FORCE_THREADPOOL_LOCK:
      global _FORCE_THREADPOOL
      old_force_threadpool, _FORCE_THREADPOOL = _FORCE_THREADPOOL, True
      out = f(*args, **kwargs)
      _FORCE_THREADPOOL = old_force_threadpool
      return out
  return wrapped


def get_pool_class(use_multiprocessing):
  """Returns the multiprocessing or threading `Pool` class to use."""
  global _FORCE_THREADPOOL
  if not use_multiprocessing or _FORCE_THREADPOOL:
    return multiprocessing.dummy.Pool  # ThreadPool
  return multiprocessing.Pool


def get_worker_id_queue():
  """Lazily create the queue to track worker ids."""
  global _WORKER_ID_QUEUE
  if _WORKER_ID_QUEUE is None:
    _WORKER_ID_QUEUE = multiprocessing.Queue()
  return _WORKER_ID_QUEUE


def init_pool(seqs):
  """Pool initializer that makes `seqs` the shared Sequences in a worker."""
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = seqs


def get_index(uid, i):
  """Get the value from the Sequence `uid` at index `i`.

  To allow multiple Sequences to be used at the same time, we use `uid` to
  get a specific one. A single Sequence would cause the validation to
  overwrite the training Sequence.

  Args:
      uid: int, Sequence identifier
      i: index

  Returns:
      The value at index `i`.
  """
  return _SHARED_SEQUENCES[uid][i]


@keras_export('keras.utils.SequenceEnqueuer')
class SequenceEnqueuer(object):
  """Base class to enqueue inputs.

  The task of an Enqueuer is to use parallelism to speed up preprocessing.
  This is done with processes or threads.

  Example:

  ```python
      enqueuer = SequenceEnqueuer(...)
      enqueuer.start()
      data_gen = enqueuer.get()
      for data in data_gen:
          # Use the inputs; training, evaluating, predicting.
          # ... stop sometime.
      enqueuer.stop()
  ```

  `enqueuer.get()` returns an infinite stream of data.
  """

  def __init__(self, sequence,
               use_multiprocessing=False):
    self.sequence = sequence
    self.use_multiprocessing = use_multiprocessing

    global _SEQUENCE_COUNTER
    if _SEQUENCE_COUNTER is None:
      try:
        _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
      except OSError:
        # In this case the OS does not allow us to use
        # multiprocessing. We resort to an int
        # for enqueuer indexing.
        _SEQUENCE_COUNTER = 0

    if isinstance(_SEQUENCE_COUNTER, int):
      self.uid = _SEQUENCE_COUNTER
      _SEQUENCE_COUNTER += 1
    else:
      # Doing Multiprocessing.Value += x is not process-safe.
      with _SEQUENCE_COUNTER.get_lock():
        self.uid = _SEQUENCE_COUNTER.value
        _SEQUENCE_COUNTER.value += 1

    self.workers = 0
    self.executor_fn = None
    self.queue = None
    self.run_thread = None
    self.stop_signal = None

  def is_running(self):
    return self.stop_signal is not None and not self.stop_signal.is_set()

  def start(self, workers=1, max_queue_size=10):
    """Starts the handler's workers.

    Args:
        workers: Number of workers.
        max_queue_size: queue size
            (when full, workers could block on `put()`)
    """
    if self.use_multiprocessing:
      self.executor_fn = self._get_executor_init(workers)
    else:
      # We do not need the init since it's threads.
      self.executor_fn = lambda _: get_pool_class(False)(workers)
    self.workers = workers
    self.queue = queue.Queue(max_queue_size)
    self.stop_signal = threading.Event()
    self.run_thread = threading.Thread(target=self._run)
    self.run_thread.daemon = True
    self.run_thread.start()

  def _send_sequence(self):
    """Sends current Iterable to all workers."""
    # For new processes that may spawn
    _SHARED_SEQUENCES[self.uid] = self.sequence

  def stop(self, timeout=None):
    """Stops running threads and waits for them to exit, if necessary.

    Should be called by the same thread which called `start()`.

    Args:
        timeout: maximum time to wait on `thread.join()`
    """
    self.stop_signal.set()
    with self.queue.mutex:
      self.queue.queue.clear()
      self.queue.unfinished_tasks = 0
      self.queue.not_full.notify()
    self.run_thread.join(timeout)
    _SHARED_SEQUENCES[self.uid] = None

  def __del__(self):
    if self.is_running():
      self.stop()

  @abstractmethod
  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    raise NotImplementedError

  @abstractmethod
  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
        workers: Number of workers.

    Returns:
        A function that initializes the pool.
    """
    raise NotImplementedError

  @abstractmethod
  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Returns:
        Generator yielding tuples `(inputs, targets)`
            or `(inputs, targets, sample_weights)`.
    """
    raise NotImplementedError


@keras_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
  """Builds an Enqueuer from a Sequence.

  Args:
      sequence: A `tf.keras.utils.data_utils.Sequence` object.
      use_multiprocessing: use multiprocessing if True, otherwise threading
      shuffle: whether to shuffle the data at the beginning of each epoch
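
  Example (a minimal sketch; `MySequence` stands for any user-defined
  `Sequence` subclass):

  ```python
  enqueuer = OrderedEnqueuer(MySequence(), shuffle=True)
  enqueuer.start(workers=2, max_queue_size=10)
  batches = enqueuer.get()  # Generator yielding batches in order.
  for _ in range(len(enqueuer.sequence)):
      batch = next(batches)
      # ... consume `batch` ...
  enqueuer.stop()
  ```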
  """

  def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
    super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.shuffle = shuffle

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
        workers: Number of workers.

    Returns:
        A function that initializes the pool.
    """
    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, None, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool

    return pool_fn

  def _wait_queue(self):
    """Wait for the queue to be empty."""
    while True:
      time.sleep(0.1)
      if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
        return

  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    sequence = list(range(len(self.sequence)))
    self._send_sequence()  # Share the initial sequence
    while True:
      if self.shuffle:
        random.shuffle(sequence)

      with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
        for i in sequence:
          if self.stop_signal.is_set():
            return

          self.queue.put(
              executor.apply_async(get_index, (self.uid, i)), block=True)

        # Done with the current epoch, waiting for the final batches
        self._wait_queue()

        if self.stop_signal.is_set():
          # We're done
          return

      # Call the internal on epoch end.
      self.sequence.on_epoch_end()
      self._send_sequence()  # Update the pool

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
        The next element in the queue, i.e. a tuple
        `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
    """
    while self.is_running():
      try:
        inputs = self.queue.get(block=True, timeout=5).get()
        if self.is_running():
          self.queue.task_done()
        if inputs is not None:
          yield inputs
      except queue.Empty:
        pass
      except Exception as e:  # pylint: disable=broad-except
        self.stop()
        raise e


def init_pool_generator(gens, random_seed=None, id_queue=None):
  """Initializer function for pool workers.

  Args:
    gens: State which should be made available to worker processes.
    random_seed: An optional value with which to seed child processes.
    id_queue: A multiprocessing Queue of worker ids. This is used to indicate
      that a worker process was created by Keras and can be terminated using
      the cleanup_all_keras_forkpools utility.
  """
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = gens

  worker_proc = multiprocessing.current_process()

  # name isn't used for anything, but setting a more descriptive name is helpful
  # when diagnosing orphaned processes.
  worker_proc.name = 'Keras_worker_{}'.format(worker_proc.name)

  if random_seed is not None:
    np.random.seed(random_seed + worker_proc.ident)

  if id_queue is not None:
    # If a worker dies during init, the pool will just create a replacement.
    id_queue.put(worker_proc.ident, block=True, timeout=0.1)


def next_sample(uid):
  """Gets the next value from the generator `uid`.

  To allow multiple generators to be used at the same time, we use `uid` to
  get a specific one. A single generator would cause the validation to
  overwrite the training generator.

  Args:
      uid: int, generator identifier

  Returns:
      The next value of generator `uid`.
  """
  return next(_SHARED_SEQUENCES[uid])


@keras_export('keras.utils.GeneratorEnqueuer')
class GeneratorEnqueuer(SequenceEnqueuer):
  """Builds a queue out of a data generator.

  The provided generator can be finite, in which case the class will throw
  a `StopIteration` exception.

  Args:
      generator: a generator function which yields data
      use_multiprocessing: use multiprocessing if True, otherwise threading
      random_seed: Initial seed for workers,
          will be incremented by one for each worker.
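
  Example (a minimal sketch; `my_generator` and `steps_per_epoch` stand for a
  user-defined generator and a caller-chosen step count):

  ```python
  enqueuer = GeneratorEnqueuer(my_generator())
  enqueuer.start(workers=2, max_queue_size=10)
  batches = enqueuer.get()
  for _ in range(steps_per_epoch):
      batch = next(batches)
      # ... consume `batch` ...
  enqueuer.stop()
  ```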
  """

  def __init__(self, generator,
               use_multiprocessing=False,
               random_seed=None):
    super(GeneratorEnqueuer, self).__init__(generator, use_multiprocessing)
    self.random_seed = random_seed

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
        A function that initializes the pool.
    """
    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, self.random_seed, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool
    return pool_fn

  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    self._send_sequence()  # Share the initial generator
    with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
      while True:
        if self.stop_signal.is_set():
          return

        self.queue.put(
            executor.apply_async(next_sample, (self.uid,)), block=True)

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
        The next element in the queue, i.e. a tuple
        `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
    """
    try:
      while self.is_running():
        inputs = self.queue.get(block=True).get()
        self.queue.task_done()
        if inputs is not None:
          yield inputs
    except StopIteration:
      # Special case for finite generators
      last_ones = []
      while self.queue.qsize() > 0:
        last_ones.append(self.queue.get(block=True))
      # Wait for them to complete
      for f in last_ones:
        f.wait()
      # Keep the good ones
      last_ones = [future.get() for future in last_ones if future.successful()]
      for inputs in last_ones:
        if inputs is not None:
          yield inputs
    except Exception as e:  # pylint: disable=broad-except
      self.stop()
      if 'generator already executing' in str(e):
        raise RuntimeError(
            'Your generator is NOT thread-safe. '
            'Keras requires a thread-safe generator when '
            '`use_multiprocessing=False, workers > 1`. ')
      raise e