# Lint as python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
"""Utilities for file download and caching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from abc import abstractmethod
from contextlib import closing
import errno
import functools
import hashlib
import multiprocessing
import multiprocessing.dummy
import os
import random
import shutil
import sys
import tarfile
import threading
import time
import weakref
import zipfile

import numpy as np
import six
from six.moves.urllib.error import HTTPError
from six.moves.urllib.error import URLError
from six.moves.urllib.request import urlopen

from tensorflow.python.framework import ops
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.util.tf_export import keras_export


try:
  import queue
except ImportError:
  import Queue as queue

try:
  import typing
  is_iterator = lambda x: isinstance(x, typing.Iterator)
except ImportError:
  # Python2 uses next, and Python3 should have typing so __next__ is not needed.
  is_iterator = lambda x: hasattr(x, '__iter__') and hasattr(x, 'next')


if sys.version_info[0] == 2:

  def urlretrieve(url, filename, reporthook=None, data=None):
    """Replacement for `urlretrieve` for Python 2.

    Under Python 2, `urlretrieve` relies on `FancyURLopener` from legacy
    `urllib` module, known to have issues with proxy management.

    Args:
      url: url to retrieve.
      filename: where to store the retrieved data locally.
      reporthook: a hook function that will be called once on establishment of
        the network connection and once after each block read thereafter. The
        hook will be passed three arguments; a count of blocks transferred so
        far, a block size in bytes, and the total size of the file.
      data: `data` argument passed to `urlopen`.
81 """ 82 83 def chunk_read(response, chunk_size=8192, reporthook=None): 84 content_type = response.info().get('Content-Length') 85 total_size = -1 86 if content_type is not None: 87 total_size = int(content_type.strip()) 88 count = 0 89 while True: 90 chunk = response.read(chunk_size) 91 count += 1 92 if reporthook is not None: 93 reporthook(count, chunk_size, total_size) 94 if chunk: 95 yield chunk 96 else: 97 break 98 99 response = urlopen(url, data) 100 with open(filename, 'wb') as fd: 101 for chunk in chunk_read(response, reporthook=reporthook): 102 fd.write(chunk) 103else: 104 from six.moves.urllib.request import urlretrieve 105 106 107def is_generator_or_sequence(x): 108 """Check if `x` is a Keras generator type.""" 109 builtin_iterators = (str, list, tuple, dict, set, frozenset) 110 if isinstance(x, (ops.Tensor, np.ndarray) + builtin_iterators): 111 return False 112 return tf_inspect.isgenerator(x) or isinstance(x, Sequence) or is_iterator(x) 113 114 115def _extract_archive(file_path, path='.', archive_format='auto'): 116 """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats. 117 118 Args: 119 file_path: path to the archive file 120 path: path to extract the archive file 121 archive_format: Archive format to try for extracting the file. 122 Options are 'auto', 'tar', 'zip', and None. 123 'tar' includes tar, tar.gz, and tar.bz files. 124 The default 'auto' is ['tar', 'zip']. 125 None or an empty list will return no matches found. 126 127 Returns: 128 True if a match was found and an archive extraction was completed, 129 False otherwise. 130 """ 131 if archive_format is None: 132 return False 133 if archive_format == 'auto': 134 archive_format = ['tar', 'zip'] 135 if isinstance(archive_format, six.string_types): 136 archive_format = [archive_format] 137 138 file_path = path_to_string(file_path) 139 path = path_to_string(path) 140 141 for archive_type in archive_format: 142 if archive_type == 'tar': 143 open_fn = tarfile.open 144 is_match_fn = tarfile.is_tarfile 145 if archive_type == 'zip': 146 open_fn = zipfile.ZipFile 147 is_match_fn = zipfile.is_zipfile 148 149 if is_match_fn(file_path): 150 with open_fn(file_path) as archive: 151 try: 152 archive.extractall(path) 153 except (tarfile.TarError, RuntimeError, KeyboardInterrupt): 154 if os.path.exists(path): 155 if os.path.isfile(path): 156 os.remove(path) 157 else: 158 shutil.rmtree(path) 159 raise 160 return True 161 return False 162 163 164@keras_export('keras.utils.get_file') 165def get_file(fname, 166 origin, 167 untar=False, 168 md5_hash=None, 169 file_hash=None, 170 cache_subdir='datasets', 171 hash_algorithm='auto', 172 extract=False, 173 archive_format='auto', 174 cache_dir=None): 175 """Downloads a file from a URL if it not already in the cache. 176 177 By default the file at the url `origin` is downloaded to the 178 cache_dir `~/.keras`, placed in the cache_subdir `datasets`, 179 and given the filename `fname`. The final location of a file 180 `example.txt` would therefore be `~/.keras/datasets/example.txt`. 181 182 Files in tar, tar.gz, tar.bz, and zip formats can also be extracted. 183 Passing a hash will verify the file after download. The command line 184 programs `shasum` and `sha256sum` can compute the hash. 185 186 Example: 187 188 ```python 189 path_to_downloaded_file = tf.keras.utils.get_file( 190 "flower_photos", 191 "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz", 192 untar=True) 193 ``` 194 195 Args: 196 fname: Name of the file. 
      If an absolute path `/path/to/file.txt` is
      specified, the file will be saved at that location.
    origin: Original URL of the file.
    untar: Deprecated in favor of the `extract` argument.
      Boolean, whether the file should be decompressed.
    md5_hash: Deprecated in favor of the `file_hash` argument.
      md5 hash of the file for verification.
    file_hash: The expected hash string of the file after download.
      The sha256 and md5 hash algorithms are both supported.
    cache_subdir: Subdirectory under the Keras cache dir where the file is
      saved. If an absolute path `/path/to/folder` is
      specified, the file will be saved at that location.
    hash_algorithm: Select the hash algorithm to verify the file.
      Options are `'md5'`, `'sha256'`, and `'auto'`.
      The default 'auto' detects the hash algorithm in use.
    extract: True tries extracting the file as an archive, like tar or zip.
    archive_format: Archive format to try for extracting the file.
      Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
      `'tar'` includes tar, tar.gz, and tar.bz files.
      The default `'auto'` corresponds to `['tar', 'zip']`.
      None or an empty list will return no matches found.
    cache_dir: Location to store cached files; when None it
      defaults to `~/.keras/`.

  Returns:
    Path to the downloaded file
  """
  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'
  datadir_base = os.path.expanduser(cache_dir)
  if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join('/tmp', '.keras')
  datadir = os.path.join(datadir_base, cache_subdir)
  _makedirs_exist_ok(datadir)

  fname = path_to_string(fname)

  if untar:
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  download = False
  if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if file_hash is not None:
      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
        print('A local file was found, but it seems to be incomplete or '
              'outdated because the ' + hash_algorithm +
              ' file hash does not match the original value of ' + file_hash +
              ', so we will re-download the data.')
        download = True
  else:
    download = True

  if download:
    print('Downloading data from', origin)

    class ProgressTracker(object):
      # Maintain progbar for the lifetime of download.
      # This design was chosen for Python 2.7 compatibility.
      progbar = None

    def dl_progress(count, block_size, total_size):
      if ProgressTracker.progbar is None:
        if total_size == -1:
          total_size = None
        ProgressTracker.progbar = Progbar(total_size)
      else:
        ProgressTracker.progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        urlretrieve(origin, fpath, dl_progress)
      except HTTPError as e:
        raise Exception(error_msg.format(origin, e.code, e.msg))
      except URLError as e:
        raise Exception(error_msg.format(origin, e.errno, e.reason))
    except (Exception, KeyboardInterrupt) as e:
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    ProgressTracker.progbar = None

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  return fpath


def _makedirs_exist_ok(datadir):
  if six.PY2:
    # Python 2 doesn't have the exist_ok arg, so we try-except here.
    try:
      os.makedirs(datadir)
    except OSError as e:
      if e.errno != errno.EEXIST:
        raise
  else:
    os.makedirs(datadir, exist_ok=True)  # pylint: disable=unexpected-keyword-arg


def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
  """Calculates a file sha256 or md5 hash.

  Example:

  ```python
  _hash_file('/path/to/file.zip')
  'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
  ```

  Args:
    fpath: path to the file being validated
    algorithm: hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`.
      The default `'auto'` falls back to sha256, since there is no expected
      hash here from which to detect the algorithm.
    chunk_size: Bytes to read at a time, important for large files.

  Returns:
    The file hash
  """
  if algorithm in ('sha256', 'auto'):
    # Unlike `validate_file`, this function has no expected hash string to
    # inspect, so 'auto' defaults to sha256.
    hasher = hashlib.sha256()
  else:
    hasher = hashlib.md5()

  with open(fpath, 'rb') as fpath_file:
    for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
      hasher.update(chunk)

  return hasher.hexdigest()


def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
  """Validates a file against a sha256 or md5 hash.

  Args:
    fpath: path to the file being validated
    file_hash: The expected hash string of the file.
      The sha256 and md5 hash algorithms are both supported.
    algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
      The default 'auto' detects the hash algorithm in use.
    chunk_size: Bytes to read at a time, important for large files.

  Returns:
    Whether the file is valid
  """
  if (algorithm == 'sha256') or (algorithm == 'auto' and len(file_hash) == 64):
    hasher = 'sha256'
  else:
    hasher = 'md5'

  return str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash)


class ThreadsafeIter(object):
  """Wrap an iterator with a lock and propagate exceptions to all threads."""

  def __init__(self, it):
    self.it = it
    self.lock = threading.Lock()

    # After a generator throws an exception all subsequent next() calls raise
    # a StopIteration exception. This, however, presents an issue when mixing
    # generators and threading because it means the order of retrieval need
    # not match the order in which the generator was called.
    # This can make it appear that a generator exited normally when in fact
    # the terminating exception is just in a different thread. In order to
    # provide thread safety, once self.it has thrown an exception we continue
    # to throw the same exception.
    self._exception = None

  def __iter__(self):
    return self

  def next(self):
    return self.__next__()

  def __next__(self):
    with self.lock:
      if self._exception:
        raise self._exception  # pylint: disable=raising-bad-type

      try:
        return next(self.it)
      except Exception as e:
        self._exception = e
        raise


def threadsafe_generator(f):
  """Decorator that wraps a generator function in a `ThreadsafeIter`."""

  @functools.wraps(f)
  def g(*a, **kw):
    return ThreadsafeIter(f(*a, **kw))

  return g


@keras_export('keras.utils.Sequence')
class Sequence(object):
  """Base object for fitting to a sequence of data, such as a dataset.

  Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
  If you want to modify your dataset between epochs you may implement
  `on_epoch_end`.
  The method `__getitem__` should return a complete batch.

  Notes:

  `Sequence` is a safer way to do multiprocessing. This structure guarantees
  that the network will only train once on each sample per epoch, which is not
  the case with generators.

  Examples:

  ```python
  from skimage.io import imread
  from skimage.transform import resize
  import numpy as np
  import math

  # Here, `x_set` is a list of paths to the images
  # and `y_set` are the associated classes.

  class CIFAR10Sequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
      self.x, self.y = x_set, y_set
      self.batch_size = batch_size

    def __len__(self):
      return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
      batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
      batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

      return np.array([
          resize(imread(file_name), (200, 200))
          for file_name in batch_x]), np.array(batch_y)
  ```
  """

  @abstractmethod
  def __getitem__(self, index):
    """Gets batch at position `index`.

    Args:
      index: position of the batch in the Sequence.

    Returns:
      A batch
    """
    raise NotImplementedError

  @abstractmethod
  def __len__(self):
    """Number of batches in the Sequence.

    Returns:
      The number of batches in the Sequence.
    """
    raise NotImplementedError

  def on_epoch_end(self):
    """Method called at the end of every epoch."""
    pass

  def __iter__(self):
    """Creates a generator that iterates over the Sequence."""
    for item in (self[i] for i in range(len(self))):
      yield item


def iter_sequence_infinite(seq):
  """Iterates indefinitely over a Sequence.

  Args:
    seq: `Sequence` instance.

  Yields:
    Batches of data from the `Sequence`.
  """
  while True:
    for item in seq:
      yield item


# Global variables to be shared across processes
_SHARED_SEQUENCES = {}
# We use a Value to provide unique id to different processes.
_SEQUENCE_COUNTER = None


# Because multiprocessing pools are inherently unsafe, starting from a clean
# state can be essential to avoiding deadlocks. In order to accomplish this, we
# need to be able to check on the status of Pools that we create.
_DATA_POOLS = weakref.WeakSet()
_WORKER_ID_QUEUE = None  # Only created if needed.
_WORKER_IDS = set()
_FORCE_THREADPOOL = False
_FORCE_THREADPOOL_LOCK = threading.RLock()


def dont_use_multiprocessing_pool(f):
  """Decorator that forces thread pools for the duration of the wrapped call."""

  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    with _FORCE_THREADPOOL_LOCK:
      global _FORCE_THREADPOOL
      old_force_threadpool, _FORCE_THREADPOOL = _FORCE_THREADPOOL, True
      out = f(*args, **kwargs)
      _FORCE_THREADPOOL = old_force_threadpool
      return out

  return wrapped


def get_pool_class(use_multiprocessing):
  """Returns a process-based or thread-based Pool class."""
  global _FORCE_THREADPOOL
  if not use_multiprocessing or _FORCE_THREADPOOL:
    return multiprocessing.dummy.Pool  # ThreadPool
  return multiprocessing.Pool


def get_worker_id_queue():
  """Lazily create the queue to track worker ids."""
  global _WORKER_ID_QUEUE
  if _WORKER_ID_QUEUE is None:
    _WORKER_ID_QUEUE = multiprocessing.Queue()
  return _WORKER_ID_QUEUE


def init_pool(seqs):
  """Shares `seqs` with pool workers via the `_SHARED_SEQUENCES` global."""
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = seqs


def get_index(uid, i):
  """Get the value from the Sequence `uid` at index `i`.

  To allow multiple Sequences to be used at the same time, we use `uid` to
  get a specific one. A single Sequence would cause the validation to
  overwrite the training Sequence.

  Args:
    uid: int, Sequence identifier
    i: index

  Returns:
    The value at index `i`.
  """
  return _SHARED_SEQUENCES[uid][i]


@keras_export('keras.utils.SequenceEnqueuer')
class SequenceEnqueuer(object):
  """Base class to enqueue inputs.

  The task of an Enqueuer is to use parallelism to speed up preprocessing.
  This is done with processes or threads.

  Example:

  ```python
  enqueuer = SequenceEnqueuer(...)
  enqueuer.start()
  datas = enqueuer.get()
  for data in datas:
      # Use the inputs; training, evaluating, predicting.
      # ... stop sometime.
  enqueuer.stop()
  ```

  `enqueuer.get()` should be an infinite stream of data.
  """

  def __init__(self, sequence, use_multiprocessing=False):
    self.sequence = sequence
    self.use_multiprocessing = use_multiprocessing

    global _SEQUENCE_COUNTER
    if _SEQUENCE_COUNTER is None:
      try:
        _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
      except OSError:
        # In this case the OS does not allow us to use
        # multiprocessing. We resort to an int
        # for enqueuer indexing.
        _SEQUENCE_COUNTER = 0

    if isinstance(_SEQUENCE_COUNTER, int):
      self.uid = _SEQUENCE_COUNTER
      _SEQUENCE_COUNTER += 1
    else:
      # Doing Multiprocessing.Value += x is not process-safe.
      with _SEQUENCE_COUNTER.get_lock():
        self.uid = _SEQUENCE_COUNTER.value
        _SEQUENCE_COUNTER.value += 1

    self.workers = 0
    self.executor_fn = None
    self.queue = None
    self.run_thread = None
    self.stop_signal = None

  def is_running(self):
    return self.stop_signal is not None and not self.stop_signal.is_set()

  def start(self, workers=1, max_queue_size=10):
    """Starts the handler's workers.

    Args:
      workers: Number of workers.
      max_queue_size: queue size
        (when full, workers could block on `put()`)
    """
    if self.use_multiprocessing:
      self.executor_fn = self._get_executor_init(workers)
    else:
      # We do not need the init since it's threads.
      self.executor_fn = lambda _: get_pool_class(False)(workers)
    self.workers = workers
    self.queue = queue.Queue(max_queue_size)
    self.stop_signal = threading.Event()
    self.run_thread = threading.Thread(target=self._run)
    self.run_thread.daemon = True
    self.run_thread.start()

  def _send_sequence(self):
    """Sends the current Iterable to all workers."""
    # For new processes that may spawn
    _SHARED_SEQUENCES[self.uid] = self.sequence

  def stop(self, timeout=None):
    """Stops running threads and waits for them to exit, if necessary.

    Should be called by the same thread which called `start()`.

    Args:
      timeout: maximum time to wait on `thread.join()`
    """
    self.stop_signal.set()
    with self.queue.mutex:
      self.queue.queue.clear()
      self.queue.unfinished_tasks = 0
      self.queue.not_full.notify()
    self.run_thread.join(timeout)
    _SHARED_SEQUENCES[self.uid] = None

  def __del__(self):
    if self.is_running():
      self.stop()

  @abstractmethod
  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    raise NotImplementedError

  @abstractmethod
  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
      A function to initialize the pool
    """
    raise NotImplementedError

  @abstractmethod
  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Returns:
      Generator yielding tuples `(inputs, targets)`
      or `(inputs, targets, sample_weights)`.
    """
    raise NotImplementedError


@keras_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
  """Builds an Enqueuer from a Sequence.

  Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

  Args:
    sequence: A `tf.keras.utils.data_utils.Sequence` object.
    use_multiprocessing: use multiprocessing if True, otherwise threading
    shuffle: whether to shuffle the data at the beginning of each epoch
  """

  def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
    super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.shuffle = shuffle

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
      A function to initialize the pool
    """

    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, None, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool

    return pool_fn

  def _wait_queue(self):
    """Wait for the queue to be empty."""
    while True:
      time.sleep(0.1)
      if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
        return

  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    sequence = list(range(len(self.sequence)))
    self._send_sequence()  # Share the initial sequence
    while True:
      if self.shuffle:
        random.shuffle(sequence)

      with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
        for i in sequence:
          if self.stop_signal.is_set():
            return

          self.queue.put(
              executor.apply_async(get_index, (self.uid, i)), block=True)

        # Done with the current epoch, waiting for the final batches
        self._wait_queue()

        if self.stop_signal.is_set():
          # We're done
          return

      # Call the internal on epoch end.
      self.sequence.on_epoch_end()
      self._send_sequence()  # Update the pool

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
      The next element in the queue, i.e. a tuple
      `(inputs, targets)` or
      `(inputs, targets, sample_weights)`.
    """
    while self.is_running():
      try:
        inputs = self.queue.get(block=True, timeout=5).get()
        if self.is_running():
          self.queue.task_done()
        if inputs is not None:
          yield inputs
      except queue.Empty:
        pass
      except Exception:  # pylint: disable=broad-except
        self.stop()
        six.reraise(*sys.exc_info())


def init_pool_generator(gens, random_seed=None, id_queue=None):
  """Initializer function for pool workers.

  Args:
    gens: State which should be made available to worker processes.
    random_seed: An optional value with which to seed child processes.
    id_queue: A multiprocessing Queue of worker ids. This is used to indicate
      that a worker process was created by Keras and can be terminated using
      the cleanup_all_keras_forkpools utility.
  """
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = gens

  worker_proc = multiprocessing.current_process()

  # name isn't used for anything, but setting a more descriptive name is
  # helpful when diagnosing orphaned processes.
  worker_proc.name = 'Keras_worker_{}'.format(worker_proc.name)

  if random_seed is not None:
    np.random.seed(random_seed + worker_proc.ident)

  if id_queue is not None:
    # If a worker dies during init, the pool will just create a replacement.
    id_queue.put(worker_proc.ident, block=True, timeout=0.1)


def next_sample(uid):
  """Gets the next value from the generator `uid`.

  To allow multiple generators to be used at the same time, we use `uid` to
  get a specific one. A single generator would cause the validation to
  overwrite the training generator.

  Args:
    uid: int, generator identifier

  Returns:
    The next value of generator `uid`.
  """
  return six.next(_SHARED_SEQUENCES[uid])


@keras_export('keras.utils.GeneratorEnqueuer')
class GeneratorEnqueuer(SequenceEnqueuer):
  """Builds a queue out of a data generator.

  The provided generator can be finite, in which case the class will throw
  a `StopIteration` exception.

  Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

  Args:
    sequence: a generator which yields data
    use_multiprocessing: use multiprocessing if True, otherwise threading
    random_seed: Initial seed for workers,
      will be incremented by one for each worker.
  """

  def __init__(self, sequence,
               use_multiprocessing=False,
               random_seed=None):
    super(GeneratorEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.random_seed = random_seed

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
      A function to initialize the pool
    """

    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, self.random_seed, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool

    return pool_fn

  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    self._send_sequence()  # Share the initial generator
    with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
      while True:
        if self.stop_signal.is_set():
          return

        self.queue.put(
            executor.apply_async(next_sample, (self.uid,)), block=True)

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
      The next element in the queue, i.e. a tuple
      `(inputs, targets)` or
      `(inputs, targets, sample_weights)`.
    """
    try:
      while self.is_running():
        inputs = self.queue.get(block=True).get()
        self.queue.task_done()
        if inputs is not None:
          yield inputs
    except StopIteration:
      # Special case for finite generators
      last_ones = []
      while self.queue.qsize() > 0:
        last_ones.append(self.queue.get(block=True))
      # Wait for them to complete
      for f in last_ones:
        f.wait()
      # Keep the good ones
      last_ones = [future.get() for future in last_ones if future.successful()]
      for inputs in last_ones:
        if inputs is not None:
          yield inputs
    except Exception as e:  # pylint: disable=broad-except
      self.stop()
      if 'generator already executing' in str(e):
        raise RuntimeError(
            'Your generator is NOT thread-safe. '
            'Keras requires a thread-safe generator when '
            '`use_multiprocessing=False, workers > 1`. ')
      six.reraise(*sys.exc_info())
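

# ------------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal
# `Sequence` serving in-memory numpy batches through an `OrderedEnqueuer`.
# All names introduced here (`_DemoSequence`, `demo_seq`, `batches`) are
# hypothetical and exist only for demonstration; the guard below keeps the
# sketch from running on import.
if __name__ == '__main__':

  class _DemoSequence(Sequence):
    """Serves `(inputs, targets)` batches drawn from two in-memory arrays."""

    def __init__(self, x, y, batch_size=32):
      self.x, self.y = x, y
      self.batch_size = batch_size

    def __len__(self):
      # Number of batches per epoch.
      return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
      # Return one complete batch.
      start = idx * self.batch_size
      end = start + self.batch_size
      return self.x[start:end], self.y[start:end]

  demo_seq = _DemoSequence(np.random.random((256, 8)),
                           np.random.randint(0, 2, size=(256,)))
  enqueuer = OrderedEnqueuer(demo_seq, use_multiprocessing=False, shuffle=True)
  enqueuer.start(workers=2, max_queue_size=4)
  batches = enqueuer.get()  # Infinite generator over `demo_seq`.
  for _ in range(len(demo_seq)):  # Consume exactly one epoch.
    inputs, targets = next(batches)
    print('batch shapes:', inputs.shape, targets.shape)
  enqueuer.stop()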