# Lint as python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
"""Utilities for file download and caching."""

from abc import abstractmethod
from contextlib import closing
import functools
import hashlib
import multiprocessing
import multiprocessing.dummy
import os
import queue
import random
import shutil
import sys  # pylint: disable=unused-import
import tarfile
import threading
import time
import typing
import urllib
import weakref
import zipfile

import numpy as np

from tensorflow.python.framework import ops
from six.moves.urllib.request import urlopen
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.util.tf_export import keras_export

# Required to support google internal urlretrieve
if sys.version_info[0] == 2:

  def urlretrieve(url, filename, reporthook=None, data=None):
    """Replacement for `urlretrieve` for Python 2.

    Under Python 2, `urlretrieve` relies on `FancyURLopener` from legacy
    `urllib` module, known to have issues with proxy management.

    Args:
      url: url to retrieve.
      filename: where to store the retrieved data locally.
      reporthook: a hook function that will be called once on establishment of
        the network connection and once after each block read thereafter. The
        hook will be passed three arguments; a count of blocks transferred so
        far, a block size in bytes, and the total size of the file.
      data: `data` argument passed to `urlopen`.
    """

    def chunk_read(response, chunk_size=8192, reporthook=None):
      content_type = response.info().get('Content-Length')
      total_size = -1
      if content_type is not None:
        total_size = int(content_type.strip())
      count = 0
      while True:
        chunk = response.read(chunk_size)
        count += 1
        if reporthook is not None:
          reporthook(count, chunk_size, total_size)
        if chunk:
          yield chunk
        else:
          break

    response = urlopen(url, data)
    with open(filename, 'wb') as fd:
      for chunk in chunk_read(response, reporthook=reporthook):
        fd.write(chunk)
else:
  from urllib.request import urlretrieve  # pylint: disable=g-importing-member


def is_generator_or_sequence(x):
  """Check if `x` is a Keras generator type."""
  builtin_iterators = (str, list, tuple, dict, set, frozenset)
  if isinstance(x, (ops.Tensor, np.ndarray) + builtin_iterators):
    return False
  return (tf_inspect.isgenerator(x) or
          isinstance(x, Sequence) or
          isinstance(x, typing.Iterator))


def _extract_archive(file_path, path='.', archive_format='auto'):
  """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.
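
  Example (an illustrative sketch; the archive and target paths below are
  placeholders):

  ```python
  # Returns True if `/tmp/data.tar.gz` was recognized and fully extracted.
  _extract_archive('/tmp/data.tar.gz', path='/tmp/data', archive_format='auto')
  ```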

  Args:
    file_path: path to the archive file
    path: path to extract the archive file
    archive_format: Archive format to try for extracting the file.
      Options are 'auto', 'tar', 'zip', and None.
      'tar' includes tar, tar.gz, and tar.bz files.
      The default 'auto' is ['tar', 'zip'].
      None or an empty list will return no matches found.

  Returns:
    True if a match was found and an archive extraction was completed,
    False otherwise.
  """
  if archive_format is None:
    return False
  if archive_format == 'auto':
    archive_format = ['tar', 'zip']
  if isinstance(archive_format, str):
    archive_format = [archive_format]

  file_path = path_to_string(file_path)
  path = path_to_string(path)

  for archive_type in archive_format:
    if archive_type == 'tar':
      open_fn = tarfile.open
      is_match_fn = tarfile.is_tarfile
    if archive_type == 'zip':
      open_fn = zipfile.ZipFile
      is_match_fn = zipfile.is_zipfile

    if is_match_fn(file_path):
      with open_fn(file_path) as archive:
        try:
          archive.extractall(path)
        except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
          if os.path.exists(path):
            if os.path.isfile(path):
              os.remove(path)
            else:
              shutil.rmtree(path)
          raise
      return True
  return False


@keras_export('keras.utils.get_file')
def get_file(fname,
             origin,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
  """Downloads a file from a URL if it is not already in the cache.

  By default the file at the url `origin` is downloaded to the
  cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
  and given the filename `fname`. The final location of a file
  `example.txt` would therefore be `~/.keras/datasets/example.txt`.

  Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
  Passing a hash will verify the file after download. The command line
  programs `shasum` and `sha256sum` can compute the hash.

  Example:

  ```python
  path_to_downloaded_file = tf.keras.utils.get_file(
      "flower_photos",
      "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
      untar=True)
  ```
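
  A sketch of downloading with integrity checking and extraction (the URL and
  hash below are placeholders; substitute the real values for your file):

  ```python
  path = tf.keras.utils.get_file(
      "data.tar.gz",
      "https://example.com/data.tar.gz",      # placeholder origin
      file_hash="<sha256 of data.tar.gz>",    # placeholder hash
      extract=True)
  ```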

  Args:
    fname: Name of the file. If an absolute path `/path/to/file.txt` is
      specified the file will be saved at that location.
    origin: Original URL of the file.
    untar: Deprecated in favor of the `extract` argument.
      Boolean, whether the file should be decompressed.
    md5_hash: Deprecated in favor of the `file_hash` argument.
      md5 hash of the file for verification.
    file_hash: The expected hash string of the file after download.
      The sha256 and md5 hash algorithms are both supported.
    cache_subdir: Subdirectory under the Keras cache dir where the file is
      saved. If an absolute path `/path/to/folder` is
      specified the file will be saved at that location.
    hash_algorithm: Select the hash algorithm to verify the file.
      Options are `'md5'`, `'sha256'`, and `'auto'`.
      The default 'auto' detects the hash algorithm in use.
    extract: If True, tries extracting the file as an archive, like tar or zip.
    archive_format: Archive format to try for extracting the file.
      Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
      `'tar'` includes tar, tar.gz, and tar.bz files.
      The default `'auto'` corresponds to `['tar', 'zip']`.
      None or an empty list will return no matches found.
    cache_dir: Location to store cached files; when None it
      defaults to `~/.keras/`.

  Returns:
    Path to the downloaded file
  """
  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'
  datadir_base = os.path.expanduser(cache_dir)
  if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join('/tmp', '.keras')
  datadir = os.path.join(datadir_base, cache_subdir)
  _makedirs_exist_ok(datadir)

  fname = path_to_string(fname)

  if untar:
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  download = False
  if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if file_hash is not None:
      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
        print('A local file was found, but it seems to be '
              'incomplete or outdated because the ' + hash_algorithm +
              ' file hash does not match the original value of ' + file_hash +
              ' so we will re-download the data.')
        download = True
  else:
    download = True

  if download:
    print('Downloading data from', origin)

    class ProgressTracker(object):
      # Maintain progbar for the lifetime of download.
      # This design was chosen for Python 2.7 compatibility.
      progbar = None

    def dl_progress(count, block_size, total_size):
      if ProgressTracker.progbar is None:
        if total_size == -1:
          total_size = None
        ProgressTracker.progbar = Progbar(total_size)
      else:
        ProgressTracker.progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        urlretrieve(origin, fpath, dl_progress)
      except urllib.error.HTTPError as e:
        raise Exception(error_msg.format(origin, e.code, e.msg))
      except urllib.error.URLError as e:
        raise Exception(error_msg.format(origin, e.errno, e.reason))
    except (Exception, KeyboardInterrupt) as e:
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    ProgressTracker.progbar = None

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  return fpath


def _makedirs_exist_ok(datadir):
  os.makedirs(datadir, exist_ok=True)  # pylint: disable=unexpected-keyword-arg


def _resolve_hasher(algorithm, file_hash=None):
  """Returns hash algorithm as hashlib function."""
  if algorithm == 'sha256':
    return hashlib.sha256()

  if algorithm == 'auto' and file_hash is not None and len(file_hash) == 64:
    return hashlib.sha256()

  # This is used only for legacy purposes.
  return hashlib.md5()


def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
  """Calculates a file sha256 or md5 hash.

  Example:

  ```python
  _hash_file('/path/to/file.zip')
  'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
  ```

  Args:
    fpath: path to the file being validated
    algorithm: hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`.
      The default `'auto'` detects the hash algorithm in use.
    chunk_size: Bytes to read at a time, important for large files.

  Returns:
    The file hash
  """
  if isinstance(algorithm, str):
    hasher = _resolve_hasher(algorithm)
  else:
    hasher = algorithm

  with open(fpath, 'rb') as fpath_file:
    for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
      hasher.update(chunk)

  return hasher.hexdigest()


def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
  """Validates a file against a sha256 or md5 hash.
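
  Example (a minimal sketch; the path below is a placeholder):

  ```python
  expected = _hash_file('/path/to/file.zip')    # record the known-good hash
  validate_file('/path/to/file.zip', expected)  # True while the file is unchanged
  ```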

  Args:
    fpath: path to the file being validated
    file_hash: The expected hash string of the file.
      The sha256 and md5 hash algorithms are both supported.
    algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
      The default 'auto' detects the hash algorithm in use.
    chunk_size: Bytes to read at a time, important for large files.

  Returns:
    Whether the file is valid
  """
  hasher = _resolve_hasher(algorithm, file_hash)

  if str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash):
    return True
  else:
    return False


class ThreadsafeIter(object):
  """Wrap an iterator with a lock and propagate exceptions to all threads."""

  def __init__(self, it):
    self.it = it
    self.lock = threading.Lock()

    # After a generator throws an exception all subsequent next() calls raise a
    # StopIteration exception. This, however, presents an issue when mixing
    # generators and threading because it means the order of retrieval need not
    # match the order in which the generator was called. This can make it
    # appear that a generator exited normally when in fact the terminating
    # exception is just in a different thread. In order to provide thread
    # safety, once self.it has thrown an exception we continue to throw the
    # same exception.
    self._exception = None

  def __iter__(self):
    return self

  def next(self):
    return self.__next__()

  def __next__(self):
    with self.lock:
      if self._exception:
        raise self._exception  # pylint: disable=raising-bad-type

      try:
        return next(self.it)
      except Exception as e:
        self._exception = e
        raise


def threadsafe_generator(f):
  """Decorator that wraps a generator function in a `ThreadsafeIter`."""

  @functools.wraps(f)
  def g(*a, **kw):
    return ThreadsafeIter(f(*a, **kw))

  return g


@keras_export('keras.utils.Sequence')
class Sequence(object):
  """Base object for fitting to a sequence of data, such as a dataset.

  Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
  If you want to modify your dataset between epochs you may implement
  `on_epoch_end`.
  The method `__getitem__` should return a complete batch.

  Notes:

  `Sequence` is a safer way to do multiprocessing. This structure guarantees
  that the network will only train once
  on each sample per epoch, which is not the case with generators.

  Examples:

  ```python
  from skimage.io import imread
  from skimage.transform import resize
  import numpy as np
  import math

  # Here, `x_set` is a list of paths to the images
  # and `y_set` are the associated classes.

  class CIFAR10Sequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
      self.x, self.y = x_set, y_set
      self.batch_size = batch_size

    def __len__(self):
      return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
      batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
      batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

      return np.array([
          resize(imread(file_name), (200, 200))
          for file_name in batch_x]), np.array(batch_y)
  ```
  """

  @abstractmethod
  def __getitem__(self, index):
    """Gets batch at position `index`.

    Args:
      index: position of the batch in the Sequence.

    Returns:
      A batch
    """
    raise NotImplementedError

  @abstractmethod
  def __len__(self):
    """Number of batches in the Sequence.

    Returns:
      The number of batches in the Sequence.
    """
    raise NotImplementedError

  def on_epoch_end(self):
    """Method called at the end of every epoch."""
    pass

  def __iter__(self):
    """Creates a generator that iterates over the Sequence."""
    for item in (self[i] for i in range(len(self))):
      yield item


def iter_sequence_infinite(seq):
  """Iterates indefinitely over a Sequence.

  Args:
    seq: `Sequence` instance.

  Yields:
    Batches of data from the `Sequence`.
  """
  while True:
    for item in seq:
      yield item


# Global variables to be shared across processes
_SHARED_SEQUENCES = {}
# We use a Value to provide a unique id to different processes.
_SEQUENCE_COUNTER = None


# Because multiprocessing pools are inherently unsafe, starting from a clean
# state can be essential to avoiding deadlocks. In order to accomplish this, we
# need to be able to check on the status of Pools that we create.
_DATA_POOLS = weakref.WeakSet()
_WORKER_ID_QUEUE = None  # Only created if needed.
_WORKER_IDS = set()
_FORCE_THREADPOOL = False
_FORCE_THREADPOOL_LOCK = threading.RLock()


def dont_use_multiprocessing_pool(f):
  """Decorator forcing pools created inside `f` to be thread pools."""

  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    with _FORCE_THREADPOOL_LOCK:
      global _FORCE_THREADPOOL
      old_force_threadpool, _FORCE_THREADPOOL = _FORCE_THREADPOOL, True
      out = f(*args, **kwargs)
      _FORCE_THREADPOOL = old_force_threadpool
      return out

  return wrapped


def get_pool_class(use_multiprocessing):
  global _FORCE_THREADPOOL
  if not use_multiprocessing or _FORCE_THREADPOOL:
    return multiprocessing.dummy.Pool  # ThreadPool
  return multiprocessing.Pool


def get_worker_id_queue():
  """Lazily create the queue to track worker ids."""
  global _WORKER_ID_QUEUE
  if _WORKER_ID_QUEUE is None:
    _WORKER_ID_QUEUE = multiprocessing.Queue()
  return _WORKER_ID_QUEUE


def init_pool(seqs):
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = seqs


def get_index(uid, i):
  """Get the value from the Sequence `uid` at index `i`.

  To allow multiple Sequences to be used at the same time, we use `uid` to
  get a specific one. A single Sequence would cause the validation to
  overwrite the training Sequence.

  Args:
    uid: int, Sequence identifier
    i: index

  Returns:
    The value at index `i`.
  """
  return _SHARED_SEQUENCES[uid][i]


@keras_export('keras.utils.SequenceEnqueuer')
class SequenceEnqueuer(object):
  """Base class to enqueue inputs.

  The task of an Enqueuer is to use parallelism to speed up preprocessing.
  This is done with processes or threads.

  Example:

  ```python
  enqueuer = SequenceEnqueuer(...)
  enqueuer.start()
  datas = enqueuer.get()
  for data in datas:
      # Use the inputs; training, evaluating, predicting.
      # ... stop sometime.
  enqueuer.stop()
  ```

  The `enqueuer.get()` should be an infinite stream of data.
  """

  def __init__(self, sequence, use_multiprocessing=False):
    self.sequence = sequence
    self.use_multiprocessing = use_multiprocessing

    global _SEQUENCE_COUNTER
    if _SEQUENCE_COUNTER is None:
      try:
        _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
      except OSError:
        # In this case the OS does not allow us to use
        # multiprocessing. We resort to an int
        # for enqueuer indexing.
        _SEQUENCE_COUNTER = 0

    if isinstance(_SEQUENCE_COUNTER, int):
      self.uid = _SEQUENCE_COUNTER
      _SEQUENCE_COUNTER += 1
    else:
      # Doing `multiprocessing.Value += x` is not process-safe.
      with _SEQUENCE_COUNTER.get_lock():
        self.uid = _SEQUENCE_COUNTER.value
        _SEQUENCE_COUNTER.value += 1

    self.workers = 0
    self.executor_fn = None
    self.queue = None
    self.run_thread = None
    self.stop_signal = None

  def is_running(self):
    return self.stop_signal is not None and not self.stop_signal.is_set()

  def start(self, workers=1, max_queue_size=10):
    """Starts the handler's workers.

    Args:
      workers: Number of workers.
      max_queue_size: queue size
        (when full, workers could block on `put()`)
    """
    if self.use_multiprocessing:
      self.executor_fn = self._get_executor_init(workers)
    else:
      # We do not need the init since it's threads.
      self.executor_fn = lambda _: get_pool_class(False)(workers)
    self.workers = workers
    self.queue = queue.Queue(max_queue_size)
    self.stop_signal = threading.Event()
    self.run_thread = threading.Thread(target=self._run)
    self.run_thread.daemon = True
    self.run_thread.start()

  def _send_sequence(self):
    """Sends current Iterable to all workers."""
    # For new processes that may spawn
    _SHARED_SEQUENCES[self.uid] = self.sequence

  def stop(self, timeout=None):
    """Stops running threads and waits for them to exit, if necessary.

    Should be called by the same thread which called `start()`.

    Args:
      timeout: maximum time to wait on `thread.join()`
    """
    self.stop_signal.set()
    with self.queue.mutex:
      self.queue.queue.clear()
      self.queue.unfinished_tasks = 0
      self.queue.not_full.notify()
    self.run_thread.join(timeout)
    _SHARED_SEQUENCES[self.uid] = None

  def __del__(self):
    if self.is_running():
      self.stop()

  @abstractmethod
  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    raise NotImplementedError

  @abstractmethod
  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
      A function to initialize the pool
    """
    raise NotImplementedError

  @abstractmethod
  def get(self):
    """Creates a generator to extract data from the queue.

    Skips the data if it is `None`.

    Yields:
      Tuples `(inputs, targets)` or `(inputs, targets, sample_weights)`.
    """
    raise NotImplementedError


@keras_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
  """Builds an Enqueuer from a Sequence.

  Args:
    sequence: A `tf.keras.utils.data_utils.Sequence` object.
    use_multiprocessing: use multiprocessing if True, otherwise threading
    shuffle: whether to shuffle the data at the beginning of each epoch
  """

  def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
    super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.shuffle = shuffle

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
      A function to initialize the pool
    """

    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, None, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool

    return pool_fn

  def _wait_queue(self):
    """Wait for the queue to be empty."""
    while True:
      time.sleep(0.1)
      if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
        return

  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    sequence = list(range(len(self.sequence)))
    self._send_sequence()  # Share the initial sequence
    while True:
      if self.shuffle:
        random.shuffle(sequence)

      with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
        for i in sequence:
          if self.stop_signal.is_set():
            return

          self.queue.put(
              executor.apply_async(get_index, (self.uid, i)), block=True)

        # Done with the current epoch, waiting for the final batches
        self._wait_queue()

        if self.stop_signal.is_set():
          # We're done
          return

      # Call the internal on epoch end.
      self.sequence.on_epoch_end()
      self._send_sequence()  # Update the pool

  def get(self):
    """Creates a generator to extract data from the queue.

    Skips the data if it is `None`.

    Yields:
      The next element in the queue, i.e. a tuple
      `(inputs, targets)` or
      `(inputs, targets, sample_weights)`.
    """
    while self.is_running():
      try:
        inputs = self.queue.get(block=True, timeout=5).get()
        if self.is_running():
          self.queue.task_done()
        if inputs is not None:
          yield inputs
      except queue.Empty:
        pass
      except Exception as e:  # pylint: disable=broad-except
        self.stop()
        raise e


def init_pool_generator(gens, random_seed=None, id_queue=None):
  """Initializer function for pool workers.

  Args:
    gens: State which should be made available to worker processes.
    random_seed: An optional value with which to seed child processes.
    id_queue: A multiprocessing Queue of worker ids. This is used to indicate
      that a worker process was created by Keras and can be terminated using
      the cleanup_all_keras_forkpools utility.
  """
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = gens

  worker_proc = multiprocessing.current_process()

  # name isn't used for anything, but setting a more descriptive name is
  # helpful when diagnosing orphaned processes.
  worker_proc.name = 'Keras_worker_{}'.format(worker_proc.name)

  if random_seed is not None:
    np.random.seed(random_seed + worker_proc.ident)

  if id_queue is not None:
    # If a worker dies during init, the pool will just create a replacement.
    id_queue.put(worker_proc.ident, block=True, timeout=0.1)


def next_sample(uid):
  """Gets the next value from the generator `uid`.

  To allow multiple generators to be used at the same time, we use `uid` to
  get a specific one. A single generator would cause the validation to
  overwrite the training generator.

  Args:
    uid: int, generator identifier

  Returns:
    The next value of generator `uid`.
  """
  return next(_SHARED_SEQUENCES[uid])


@keras_export('keras.utils.GeneratorEnqueuer')
class GeneratorEnqueuer(SequenceEnqueuer):
  """Builds a queue out of a data generator.

  The provided generator can be finite, in which case the class will throw
  a `StopIteration` exception.

  Args:
    generator: a generator function which yields data
    use_multiprocessing: use multiprocessing if True, otherwise threading
    random_seed: Initial seed for workers,
      will be incremented by one for each worker.
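
  Example (a minimal sketch; `data_gen` is a stand-in for your own generator):

  ```python
  def data_gen():
    while True:
      yield np.random.random((32, 10)), np.random.random((32, 1))

  enqueuer = GeneratorEnqueuer(data_gen(), use_multiprocessing=False)
  enqueuer.start(workers=1, max_queue_size=10)
  batches = enqueuer.get()
  inputs, targets = next(batches)
  enqueuer.stop()
  ```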
  """

  def __init__(self, generator, use_multiprocessing=False, random_seed=None):
    super(GeneratorEnqueuer, self).__init__(generator, use_multiprocessing)
    self.random_seed = random_seed

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
      A function to initialize the pool
    """

    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, self.random_seed, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool

    return pool_fn

  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    self._send_sequence()  # Share the initial generator
    with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
      while True:
        if self.stop_signal.is_set():
          return

        self.queue.put(
            executor.apply_async(next_sample, (self.uid,)), block=True)

  def get(self):
    """Creates a generator to extract data from the queue.

    Skips the data if it is `None`.

    Yields:
      The next element in the queue, i.e. a tuple
      `(inputs, targets)` or
      `(inputs, targets, sample_weights)`.
    """
    try:
      while self.is_running():
        inputs = self.queue.get(block=True).get()
        self.queue.task_done()
        if inputs is not None:
          yield inputs
    except StopIteration:
      # Special case for finite generators
      last_ones = []
      while self.queue.qsize() > 0:
        last_ones.append(self.queue.get(block=True))
      # Wait for them to complete
      for f in last_ones:
        f.wait()
      # Keep the good ones
      last_ones = [future.get() for future in last_ones if future.successful()]
      for inputs in last_ones:
        if inputs is not None:
          yield inputs
    except Exception as e:  # pylint: disable=broad-except
      self.stop()
      if 'generator already executing' in str(e):
        raise RuntimeError(
            'Your generator is NOT thread-safe. '
            'Keras requires a thread-safe generator when '
            '`use_multiprocessing=False, workers > 1`. ')
      raise e