# Lint as python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
"""Utilities for file download and caching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from abc import abstractmethod
from contextlib import closing
import errno
import functools
import gc
import hashlib
import multiprocessing
import multiprocessing.dummy
import os
import random
import shutil
import signal
import sys
import tarfile
import threading
import time
import weakref
import zipfile

import numpy as np
import six
from six.moves.urllib.error import HTTPError
from six.moves.urllib.error import URLError
from six.moves.urllib.request import urlopen

from tensorflow.python.framework import ops
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export


try:
  import queue
except ImportError:
  import Queue as queue

try:
  import typing
  is_iterator = lambda x: isinstance(x, typing.Iterator)
except ImportError:
  # Python 2 uses `next`; Python 3 ships `typing`, so checking for `__next__`
  # is not needed there.
  is_iterator = lambda x: hasattr(x, '__iter__') and hasattr(x, 'next')


if sys.version_info[0] == 2:

  def urlretrieve(url, filename, reporthook=None, data=None):
    """Replacement for `urlretrieve` for Python 2.

    Under Python 2, `urlretrieve` relies on `FancyURLopener` from the legacy
    `urllib` module, which is known to have issues with proxy management.

    Arguments:
      url: url to retrieve.
      filename: where to store the retrieved data locally.
      reporthook: a hook function that will be called once on establishment of
        the network connection and once after each block read thereafter. The
        hook will be passed three arguments: a count of blocks transferred so
        far, a block size in bytes, and the total size of the file.
      data: `data` argument passed to `urlopen`.
    """

    def chunk_read(response, chunk_size=8192, reporthook=None):
      content_type = response.info().get('Content-Length')
      total_size = -1
      if content_type is not None:
        total_size = int(content_type.strip())
      count = 0
      while True:
        chunk = response.read(chunk_size)
        count += 1
        if reporthook is not None:
          reporthook(count, chunk_size, total_size)
        if chunk:
          yield chunk
        else:
          break

    response = urlopen(url, data)
    with open(filename, 'wb') as fd:
      for chunk in chunk_read(response, reporthook=reporthook):
        fd.write(chunk)
else:
  from six.moves.urllib.request import urlretrieve


def is_generator_or_sequence(x):
  """Check if `x` is a Keras generator type."""
  builtin_iterators = (str, list, tuple, dict, set, frozenset)
  if isinstance(x, (ops.Tensor, np.ndarray) + builtin_iterators):
    return False
  return tf_inspect.isgenerator(x) or isinstance(x, Sequence) or is_iterator(x)


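# Editorial note: the sketch below is illustrative only and not part of the
# original module. It shows what `is_generator_or_sequence` accepts and
# rejects; the `_example_is_generator_or_sequence` name is a hypothetical
# addition and the function is never called at import time.
def _example_is_generator_or_sequence():
  """Shows the intended behavior on containers versus generators."""
  assert not is_generator_or_sequence([1, 2, 3])  # builtin container
  assert not is_generator_or_sequence(np.zeros((2, 2)))  # ndarray
  assert is_generator_or_sequence(i for i in range(3))  # generator

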
def _extract_archive(file_path, path='.', archive_format='auto'):
  """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.

  Arguments:
    file_path: path to the archive file
    path: path to extract the archive file
    archive_format: Archive format to try for extracting the file.
      Options are 'auto', 'tar', 'zip', and None.
      'tar' includes tar, tar.gz, and tar.bz files.
      The default 'auto' is ['tar', 'zip'].
      None or an empty list will return no matches found.

  Returns:
    True if a match was found and an archive extraction was completed,
    False otherwise.
  """
  if archive_format is None:
    return False
  if archive_format == 'auto':
    archive_format = ['tar', 'zip']
  if isinstance(archive_format, six.string_types):
    archive_format = [archive_format]

  for archive_type in archive_format:
    if archive_type == 'tar':
      open_fn = tarfile.open
      is_match_fn = tarfile.is_tarfile
    if archive_type == 'zip':
      open_fn = zipfile.ZipFile
      is_match_fn = zipfile.is_zipfile

    if is_match_fn(file_path):
      with open_fn(file_path) as archive:
        try:
          archive.extractall(path)
        except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
          if os.path.exists(path):
            if os.path.isfile(path):
              os.remove(path)
            else:
              shutil.rmtree(path)
          raise
      return True
  return False


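# Editorial note: an illustrative sketch of the internal `_extract_archive`
# helper, which `get_file` calls when `extract=True` or `untar=True`. The
# paths and the `_example_extract_archive` name are placeholders, and the
# function is never called at import time.
def _example_extract_archive(archive_path='/tmp/example_data.tar.gz',
                             target_dir='/tmp/example_data'):
  """Extracts a tarball if its format is recognized; logs otherwise."""
  extracted = _extract_archive(archive_path, target_dir, archive_format='tar')
  if not extracted:
    logging.warning('No matching archive format for %s', archive_path)

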
@keras_export('keras.utils.get_file')
def get_file(fname,
             origin,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
  """Downloads a file from a URL if it is not already in the cache.

  By default the file at the url `origin` is downloaded to the
  cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
  and given the filename `fname`. The final location of a file
  `example.txt` would therefore be `~/.keras/datasets/example.txt`.

  Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
  Passing a hash will verify the file after download. The command line
  programs `shasum` and `sha256sum` can compute the hash.

  Arguments:
    fname: Name of the file. If an absolute path `/path/to/file.txt` is
      specified the file will be saved at that location.
    origin: Original URL of the file.
    untar: Deprecated in favor of 'extract'.
      Boolean, whether the file should be decompressed.
    md5_hash: Deprecated in favor of 'file_hash'.
      md5 hash of the file for verification.
    file_hash: The expected hash string of the file after download.
      The sha256 and md5 hash algorithms are both supported.
    cache_subdir: Subdirectory under the Keras cache dir where the file is
      saved. If an absolute path `/path/to/folder` is
      specified the file will be saved at that location.
    hash_algorithm: Select the hash algorithm to verify the file.
      Options are 'md5', 'sha256', and 'auto'.
      The default 'auto' detects the hash algorithm in use.
    extract: True tries extracting the file as an Archive, like tar or zip.
    archive_format: Archive format to try for extracting the file.
      Options are 'auto', 'tar', 'zip', and None.
      'tar' includes tar, tar.gz, and tar.bz files.
      The default 'auto' is ['tar', 'zip'].
      None or an empty list will return no matches found.
    cache_dir: Location to store cached files, when None it
      defaults to the [Keras
      Directory](/faq/#where-is-the-keras-configuration-filed-stored).

  Returns:
    Path to the downloaded file
  """
  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'
  datadir_base = os.path.expanduser(cache_dir)
  if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join('/tmp', '.keras')
  datadir = os.path.join(datadir_base, cache_subdir)
  _makedirs_exist_ok(datadir)

  if untar:
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  download = False
  if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if file_hash is not None:
      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
        print('A local file was found, but it seems to be '
              'incomplete or outdated because the ' + hash_algorithm +
              ' file hash does not match the original value of ' + file_hash +
              ' so we will re-download the data.')
        download = True
  else:
    download = True

  if download:
    print('Downloading data from', origin)

    class ProgressTracker(object):
      # Maintain progbar for the lifetime of download.
      # This design was chosen for Python 2.7 compatibility.
      progbar = None

    def dl_progress(count, block_size, total_size):
      if ProgressTracker.progbar is None:
        if total_size == -1:
          total_size = None
        ProgressTracker.progbar = Progbar(total_size)
      else:
        ProgressTracker.progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        urlretrieve(origin, fpath, dl_progress)
      except HTTPError as e:
        raise Exception(error_msg.format(origin, e.code, e.msg))
      except URLError as e:
        raise Exception(error_msg.format(origin, e.errno, e.reason))
    except (Exception, KeyboardInterrupt) as e:
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    ProgressTracker.progbar = None

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  return fpath


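# Editorial note: an illustrative sketch of a typical `get_file` call. The URL
# and the expected hash are placeholders, not real values, and the
# `_example_get_file_usage` helper is a hypothetical addition that is never
# invoked at import time.
def _example_get_file_usage():
  """Downloads (or reuses) a cached archive and unpacks it."""
  path = get_file(
      'example_data.tar.gz',
      origin='https://example.com/example_data.tar.gz',  # placeholder URL
      file_hash='0' * 64,  # placeholder sha256 of the expected file
      extract=True,  # also unpack the archive next to the download
      cache_subdir='datasets')
  # `path` points at the cached archive under ~/.keras/datasets/.
  return path

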
def _makedirs_exist_ok(datadir):
  if six.PY2:
    # Python 2 doesn't have the exist_ok arg, so we try-except here.
    try:
      os.makedirs(datadir)
    except OSError as e:
      if e.errno != errno.EEXIST:
        raise
  else:
    os.makedirs(datadir, exist_ok=True)  # pylint: disable=unexpected-keyword-arg


def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
  """Calculates a file sha256 or md5 hash.

  Example:

  ```python
  _hash_file('/path/to/file.zip')
  'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
  ```

  Arguments:
    fpath: path to the file being validated
    algorithm: hash algorithm, one of 'auto', 'sha256', or 'md5'.
      'auto' and 'sha256' both use sha256 here; use `validate_file` to detect
      the algorithm from an expected hash string.
    chunk_size: Bytes to read at a time, important for large files.

  Returns:
    The file hash
  """
  if algorithm in ('sha256', 'auto'):
    hasher = hashlib.sha256()
  else:
    hasher = hashlib.md5()

  with open(fpath, 'rb') as fpath_file:
    for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
      hasher.update(chunk)

  return hasher.hexdigest()


def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
  """Validates a file against a sha256 or md5 hash.

  Arguments:
    fpath: path to the file being validated
    file_hash: The expected hash string of the file.
      The sha256 and md5 hash algorithms are both supported.
    algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
      The default 'auto' detects the hash algorithm in use.
    chunk_size: Bytes to read at a time, important for large files.

  Returns:
    Whether the file is valid
  """
  if (algorithm == 'sha256') or (algorithm == 'auto' and len(file_hash) == 64):
    hasher = 'sha256'
  else:
    hasher = 'md5'

  return str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash)


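# Editorial note: an illustrative sketch of how `_hash_file` and
# `validate_file` fit together. The file path is a placeholder and the
# `_example_validate_file_usage` helper is a hypothetical addition that is
# never called at import time.
def _example_validate_file_usage(fpath='/tmp/example_data.tar.gz'):
  """Hashes a downloaded file and validates it against that hash."""
  # Compute the sha256 hash of the file on disk.
  sha256 = _hash_file(fpath, algorithm='sha256')
  # `validate_file` re-hashes the file and compares against the expected
  # value; with algorithm='auto' the 64-character length selects sha256.
  assert validate_file(fpath, sha256, algorithm='auto')

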
class ThreadsafeIter(object):
  """Wrap an iterator with a lock and propagate exceptions to all threads."""

  def __init__(self, it):
    self.it = it
    self.lock = threading.Lock()

    # After a generator throws an exception all subsequent next() calls raise a
    # StopIteration Exception. This, however, presents an issue when mixing
    # generators and threading because it means the order of retrieval need not
    # match the order in which the generator was called. This can make it
    # appear that a generator exited normally when in fact the terminating
    # exception is just in a different thread. In order to provide thread
    # safety, once self.it has thrown an exception we continue to throw the
    # same exception.
    self._exception = None

  def __iter__(self):
    return self

  def __next__(self):
    return self.next()

  def next(self):
    with self.lock:
      if self._exception:
        raise self._exception  # pylint: disable=raising-bad-type

      try:
        return next(self.it)
      except Exception as e:
        self._exception = e
        raise


def threadsafe_generator(f):
  """Decorator that makes a generator function return a `ThreadsafeIter`."""

  @functools.wraps(f)
  def g(*a, **kw):
    return ThreadsafeIter(f(*a, **kw))

  return g


@keras_export('keras.utils.Sequence')
class Sequence(object):
  """Base object for fitting to a sequence of data, such as a dataset.

  Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
  If you want to modify your dataset between epochs you may implement
  `on_epoch_end`.
  The method `__getitem__` should return a complete batch.

  Notes:

  `Sequence` is a safer way to do multiprocessing. This structure guarantees
  that the network will only train once on each sample per epoch, which is not
  the case with generators.

  Examples:

  ```python
  from skimage.io import imread
  from skimage.transform import resize
  import numpy as np
  import math

  # Here, `x_set` is a list of paths to the images
  # and `y_set` are the associated classes.

  class CIFAR10Sequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
      self.x, self.y = x_set, y_set
      self.batch_size = batch_size

    def __len__(self):
      return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
      batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
      batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

      return np.array([
          resize(imread(file_name), (200, 200))
          for file_name in batch_x]), np.array(batch_y)
  ```
  """

  @abstractmethod
  def __getitem__(self, index):
    """Gets batch at position `index`.

    Arguments:
      index: position of the batch in the Sequence.

    Returns:
      A batch
    """
    raise NotImplementedError

  @abstractmethod
  def __len__(self):
    """Number of batches in the Sequence.

    Returns:
      The number of batches in the Sequence.
    """
    raise NotImplementedError

  def on_epoch_end(self):
    """Method called at the end of every epoch."""
    pass

  def __iter__(self):
    """Creates a generator that iterates over the Sequence."""
    for item in (self[i] for i in range(len(self))):
      yield item


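# Editorial note: a minimal in-memory `Sequence`, illustrating the contract
# (`__len__` returns the number of batches, `__getitem__` returns one complete
# batch). `_ExampleArraySequence` is a hypothetical example class, not a Keras
# API; it is reused by the enqueuer sketch further below.
class _ExampleArraySequence(Sequence):
  """Minimal Sequence over numpy arrays, yielding (inputs, targets) batches."""

  def __init__(self, x, y, batch_size=32):
    self.x, self.y, self.batch_size = x, y, batch_size

  def __len__(self):
    return int(np.ceil(len(self.x) / float(self.batch_size)))

  def __getitem__(self, idx):
    start = idx * self.batch_size
    end = start + self.batch_size
    return self.x[start:end], self.y[start:end]

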
def iter_sequence_infinite(seq):
  """Iterates indefinitely over a Sequence.

  Arguments:
    seq: Sequence instance.

  Yields:
    Batches of data from the Sequence.
  """
  while True:
    for item in seq:
      yield item


# Global variables to be shared across processes
_SHARED_SEQUENCES = {}
# We use a Value to provide unique id to different processes.
_SEQUENCE_COUNTER = None


# Because multiprocessing pools are inherently unsafe, starting from a clean
# state can be essential to avoiding deadlocks. In order to accomplish this, we
# need to be able to check on the status of Pools that we create.
_DATA_POOLS = weakref.WeakSet()
_WORKER_ID_QUEUE = None  # Only created if needed.
_WORKER_IDS = set()
_FORCE_THREADPOOL = False
_FORCE_THREADPOOL_LOCK = threading.RLock()


def dont_use_multiprocessing_pool(f):
  """Decorator forcing the thread pool inside the wrapped call."""

  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    with _FORCE_THREADPOOL_LOCK:
      global _FORCE_THREADPOOL
      old_force_threadpool, _FORCE_THREADPOOL = _FORCE_THREADPOOL, True
      out = f(*args, **kwargs)
      _FORCE_THREADPOOL = old_force_threadpool
      return out

  return wrapped


def get_pool_class(use_multiprocessing):
  global _FORCE_THREADPOOL
  if not use_multiprocessing or _FORCE_THREADPOOL:
    return multiprocessing.dummy.Pool  # ThreadPool
  logging.warning(
      'multiprocessing can interact badly with TensorFlow, causing '
      'nondeterministic deadlocks. For high performance data pipelines, '
      'tf.data is recommended.')
  return multiprocessing.Pool


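# Editorial note: an illustrative sketch of `get_pool_class` together with the
# `dont_use_multiprocessing_pool` decorator, which forces the thread pool for
# the duration of the wrapped call. The `_example_threadpool_map` name is a
# hypothetical addition and the function is never called at import time.
@dont_use_multiprocessing_pool
def _example_threadpool_map(values):
  """Maps a trivial function over `values` using the (forced) thread pool."""
  pool_cls = get_pool_class(use_multiprocessing=True)  # ThreadPool here
  pool = pool_cls(2)
  try:
    return pool.map(abs, values)
  finally:
    pool.close()
    pool.join()

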
def get_worker_id_queue():
  """Lazily create the queue to track worker ids."""
  global _WORKER_ID_QUEUE
  if _WORKER_ID_QUEUE is None:
    _WORKER_ID_QUEUE = multiprocessing.Queue()
  return _WORKER_ID_QUEUE


def init_pool(seqs):
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = seqs


@keras_export('keras.experimental.terminate_keras_multiprocessing_pools')
def terminate_keras_multiprocessing_pools(grace_period=0.1, use_sigkill=False):
  """Destroy Keras' multiprocessing pools to prevent deadlocks.

  In general, multiprocessing.Pool can interact quite badly with other,
  seemingly unrelated, parts of a codebase due to Pool's reliance on fork. This
  method cleans up all pools which are known to belong to Keras (and thus can
  be safely terminated).

  Args:
    grace_period: Time (in seconds) to wait for process cleanup to propagate.
    use_sigkill: Boolean of whether or not to perform a cleanup pass using
      SIGKILL.

  Returns:
    A list of human readable strings describing all issues encountered. It is
    up to the caller to decide whether to treat this as an error condition.
  """
  errors = []

  # First, clean up the pools spawned by Keras. If we start killing workers and
  # a parent pool is still alive it will just spawn replacements which we don't
  # want.
  gc.collect()
  for pool in _DATA_POOLS:
    pool.close()
    pool.terminate()
    # We do not join the pool, because that would wait forever if a worker
    # refused to exit.

    # Finally, delete our reference to the pool so that we do not block garbage
    # collection.
    del pool

  # If there were any pools, sleep for a small grace period to allow everything
  # to finalize.
  if _DATA_POOLS:
    time.sleep(grace_period)

  # Now we kill any workers which are still alive. However we must compare
  # the worker identifier to the set of identifiers which are known to have
  # been spawned by pools belonging to Keras to avoid deleting unrelated
  # workers. First we call the .terminate() method of a worker, and then if it
  # still persists we directly send a signal to the process. Certain worker
  # tasks may be able to gracefully handle shutdown, so we send a SIGTERM and
  # then optionally follow up with a SIGKILL.
  visited_workers = set()
  cleanup_passes = ['.terminate', 'SIGTERM']
  if use_sigkill:
    cleanup_passes.append('SIGKILL')
  cleanup_passes.append('log')

  for cleanup_pass in cleanup_passes:
    while True:
      # In rare cases, queue.qsize() overestimates the number of elements. This
      # loop is designed to be more robust.
      try:
        _WORKER_IDS.add(get_worker_id_queue().get_nowait())
      except queue.Empty:
        break

    gc.collect()
    workers_terminated_this_pass = False
    for worker in multiprocessing.active_children():
      ident = worker.ident
      if ident in _WORKER_IDS and worker.is_alive():
        try:
          if cleanup_pass == '.terminate':
            # First we ask nicely.
            worker.terminate()
            worker.join(timeout=grace_period)
            visited_workers.add(ident)
            workers_terminated_this_pass = True
          elif cleanup_pass in ('SIGTERM', 'SIGKILL'):
            # Then we ask increasingly tersely.
            os.kill(worker.pid, signal.SIGKILL if cleanup_pass == 'SIGKILL'
                    else signal.SIGTERM)
            workers_terminated_this_pass = True

          elif cleanup_pass == 'log':
            # And finally we give up and log the failure.
            errors.append('worker still alive: {}, pid={}, hash={}'
                          .format(worker.name, worker.pid, hash(worker)))

        except OSError:
          # Worker exited since the start of this loop.
          pass

    if workers_terminated_this_pass:
      # There can be a small propagation delay between worker destruction and
      # workers reporting False for is_alive and no longer appearing in the
      # list of active children. Once again, we sleep for a small grace period.
      # This prevents false positives from workers which are simply still in
      # the process of spinning down.
      time.sleep(grace_period)

  # Finally we remove the visited worker ids to handle the edge case that a
  # pid is reused.
  _WORKER_IDS.difference_update(visited_workers)

  gc.collect()
  for pool in _DATA_POOLS:
    errors.append('pool still exists: {}, hash={}'.format(pool, hash(pool)))

  return errors


def get_index(uid, i):
  """Get the value from the Sequence `uid` at index `i`.

  To allow multiple Sequences to be used at the same time, we use `uid` to
  get a specific one. A single Sequence would cause the validation to
  overwrite the training Sequence.

  Arguments:
    uid: int, Sequence identifier
    i: index

  Returns:
    The value at index `i`.
  """
  return _SHARED_SEQUENCES[uid][i]


@keras_export('keras.utils.SequenceEnqueuer')
class SequenceEnqueuer(object):
  """Base class to enqueue inputs.

  The task of an Enqueuer is to use parallelism to speed up preprocessing.
  This is done with processes or threads.

  Example:

  ```python
  enqueuer = OrderedEnqueuer(...)  # or another `SequenceEnqueuer` subclass
  enqueuer.start()
  data = enqueuer.get()
  for batch in data:
      # Use the batch: training, evaluating, predicting.
      # ... stop sometime.
  enqueuer.stop()
  ```

  The `enqueuer.get()` should return an infinite stream of data.
  """

  def __init__(self, sequence,
               use_multiprocessing=False):
    self.sequence = sequence
    self.use_multiprocessing = use_multiprocessing

    global _SEQUENCE_COUNTER
    if _SEQUENCE_COUNTER is None:
      try:
        _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
      except OSError:
        # In this case the OS does not allow us to use
        # multiprocessing. We resort to an int
        # for enqueuer indexing.
        _SEQUENCE_COUNTER = 0

    if isinstance(_SEQUENCE_COUNTER, int):
      self.uid = _SEQUENCE_COUNTER
      _SEQUENCE_COUNTER += 1
    else:
      # Doing Multiprocessing.Value += x is not process-safe.
      with _SEQUENCE_COUNTER.get_lock():
        self.uid = _SEQUENCE_COUNTER.value
        _SEQUENCE_COUNTER.value += 1

    self.workers = 0
    self.executor_fn = None
    self.queue = None
    self.run_thread = None
    self.stop_signal = None

  def is_running(self):
    return self.stop_signal is not None and not self.stop_signal.is_set()

  def start(self, workers=1, max_queue_size=10):
    """Starts the handler's workers.

    Arguments:
      workers: Number of workers.
      max_queue_size: queue size
        (when full, workers could block on `put()`)
    """
    if self.use_multiprocessing:
      self.executor_fn = self._get_executor_init(workers)
    else:
      # We do not need the init since it's threads.
      self.executor_fn = lambda _: get_pool_class(False)(workers)
    self.workers = workers
    self.queue = queue.Queue(max_queue_size)
    self.stop_signal = threading.Event()
    self.run_thread = threading.Thread(target=self._run)
    self.run_thread.daemon = True
    self.run_thread.start()

  def _send_sequence(self):
    """Sends current Iterable to all workers."""
    # For new processes that may spawn
    _SHARED_SEQUENCES[self.uid] = self.sequence

  def stop(self, timeout=None):
    """Stops running threads and waits for them to exit, if necessary.

    Should be called by the same thread which called `start()`.

    Arguments:
      timeout: maximum time to wait on `thread.join()`
    """
    self.stop_signal.set()
    with self.queue.mutex:
      self.queue.queue.clear()
      self.queue.unfinished_tasks = 0
      self.queue.not_full.notify()
    self.run_thread.join(timeout)
    _SHARED_SEQUENCES[self.uid] = None

  def __del__(self):
    if self.is_running():
      self.stop()

  @abstractmethod
  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    raise NotImplementedError

  @abstractmethod
  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Arguments:
      workers: Number of workers.

    Returns:
      Function, a Function to initialize the pool
    """
    raise NotImplementedError

  @abstractmethod
  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Returns:
      Generator yielding tuples `(inputs, targets)`
      or `(inputs, targets, sample_weights)`.
    """
    raise NotImplementedError


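# Editorial note: an illustrative sketch of a defensive cleanup step. After
# enqueuers started with `use_multiprocessing=True` have been stopped,
# `terminate_keras_multiprocessing_pools` (defined above) can be called, e.g.
# from test teardown, to make sure no Keras-owned worker processes linger.
# The `_example_cleanup_keras_pools` name is a hypothetical addition.
def _example_cleanup_keras_pools():
  """Terminates leftover Keras-owned pools and logs anything still alive."""
  leftover_issues = terminate_keras_multiprocessing_pools(
      grace_period=0.1, use_sigkill=False)
  for message in leftover_issues:
    logging.warning('Keras pool cleanup: %s', message)

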
@keras_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
  """Builds an Enqueuer from a Sequence.

  Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

  Arguments:
    sequence: A `tf.keras.utils.data_utils.Sequence` object.
    use_multiprocessing: use multiprocessing if True, otherwise threading
    shuffle: whether to shuffle the data at the beginning of each epoch
  """

  def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
    super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.shuffle = shuffle

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Arguments:
      workers: Number of workers.

    Returns:
      Function, a Function to initialize the pool
    """
    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, None, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool

    return pool_fn

  def _wait_queue(self):
    """Wait for the queue to be empty."""
    while True:
      time.sleep(0.1)
      if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
        return

  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    sequence = list(range(len(self.sequence)))
    self._send_sequence()  # Share the initial sequence
    while True:
      if self.shuffle:
        random.shuffle(sequence)

      with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
        for i in sequence:
          if self.stop_signal.is_set():
            return

          self.queue.put(
              executor.apply_async(get_index, (self.uid, i)), block=True)

        # Done with the current epoch, waiting for the final batches
        self._wait_queue()

        if self.stop_signal.is_set():
          # We're done
          return

      # Call the internal on epoch end.
      self.sequence.on_epoch_end()
      self._send_sequence()  # Update the pool

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
      The next element in the queue, i.e. a tuple
      `(inputs, targets)` or
      `(inputs, targets, sample_weights)`.
    """
    try:
      while self.is_running():
        inputs = self.queue.get(block=True).get()
        self.queue.task_done()
        if inputs is not None:
          yield inputs
    except Exception:  # pylint: disable=broad-except
      self.stop()
      six.reraise(*sys.exc_info())


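# Editorial note: an illustrative sketch of streaming batches from a Sequence
# with `OrderedEnqueuer`, using the hypothetical `_ExampleArraySequence` sketch
# defined earlier in this file. `_example_ordered_enqueuer` is likewise a
# hypothetical addition and is never called at import time.
def _example_ordered_enqueuer():
  """Streams a few batches from a Sequence through worker threads."""
  seq = _ExampleArraySequence(
      np.arange(100).reshape(50, 2), np.arange(50), batch_size=10)
  enqueuer = OrderedEnqueuer(seq, use_multiprocessing=False, shuffle=False)
  enqueuer.start(workers=2, max_queue_size=4)
  try:
    batches = enqueuer.get()  # generator of (inputs, targets) tuples
    for _ in range(len(seq)):
      inputs, targets = next(batches)
      del inputs, targets  # a real caller would train/evaluate here
  finally:
    enqueuer.stop()

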
def init_pool_generator(gens, random_seed=None, id_queue=None):
  """Initializer function for pool workers.

  Args:
    gens: State which should be made available to worker processes.
    random_seed: An optional value with which to seed child processes.
    id_queue: A multiprocessing Queue of worker ids. This is used to indicate
      that a worker process was created by Keras and can be terminated using
      the `terminate_keras_multiprocessing_pools` utility.
  """
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = gens

  worker_proc = multiprocessing.current_process()

  # name isn't used for anything, but setting a more descriptive name is
  # helpful when diagnosing orphaned processes.
  worker_proc.name = 'Keras_worker_{}'.format(worker_proc.name)

  if random_seed is not None:
    np.random.seed(random_seed + worker_proc.ident)

  if id_queue is not None:
    # If a worker dies during init, the pool will just create a replacement.
    id_queue.put(worker_proc.ident, block=True, timeout=0.1)


def next_sample(uid):
  """Gets the next value from the generator `uid`.

  To allow multiple generators to be used at the same time, we use `uid` to
  get a specific one. A single generator would cause the validation to
  overwrite the training generator.

  Arguments:
    uid: int, generator identifier

  Returns:
    The next value of generator `uid`.
  """
  return six.next(_SHARED_SEQUENCES[uid])


@keras_export('keras.utils.GeneratorEnqueuer')
class GeneratorEnqueuer(SequenceEnqueuer):
  """Builds a queue out of a data generator.

  The provided generator can be finite, in which case the class will throw
  a `StopIteration` exception.

  Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

  Arguments:
    sequence: a generator which yields data
    use_multiprocessing: use multiprocessing if True, otherwise threading
    random_seed: Initial seed for workers,
      will be incremented by one for each worker.
  """

  def __init__(self, sequence,
               use_multiprocessing=False,
               random_seed=None):
    super(GeneratorEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.random_seed = random_seed

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Arguments:
      workers: Number of workers.

    Returns:
      A Function to initialize the pool
    """
    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, self.random_seed, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool

    return pool_fn

  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    self._send_sequence()  # Share the initial generator
    with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
      while True:
        if self.stop_signal.is_set():
          return

        self.queue.put(
            executor.apply_async(next_sample, (self.uid,)), block=True)

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
      The next element in the queue, i.e. a tuple
      `(inputs, targets)` or
      `(inputs, targets, sample_weights)`.
    """
    try:
      while self.is_running():
        inputs = self.queue.get(block=True).get()
        self.queue.task_done()
        if inputs is not None:
          yield inputs
    except StopIteration:
      # Special case for finite generators
      last_ones = []
      while self.queue.qsize() > 0:
        last_ones.append(self.queue.get(block=True))
      # Wait for them to complete
      for f in last_ones:
        f.wait()
      # Keep the good ones
      last_ones = [future.get() for future in last_ones if future.successful()]
      for inputs in last_ones:
        if inputs is not None:
          yield inputs
    except Exception as e:  # pylint: disable=broad-except
      self.stop()
      if 'generator already executing' in str(e):
        raise RuntimeError(
            'Your generator is NOT thread-safe. '
            'Keras requires a thread-safe generator when '
            '`use_multiprocessing=False, workers > 1`. ')
      six.reraise(*sys.exc_info())


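# Editorial note: an illustrative sketch of `GeneratorEnqueuer`, which wraps a
# plain Python generator rather than a `Sequence`. The toy generator and the
# `_example_generator_enqueuer` name are hypothetical additions; nothing here
# runs at import time.
def _example_generator_enqueuer():
  """Streams one sample from a generator through a single worker thread."""

  def sample_generator():
    while True:
      yield np.random.random((8, 4)), np.random.random((8, 1))

  enqueuer = GeneratorEnqueuer(sample_generator(), use_multiprocessing=False)
  enqueuer.start(workers=1, max_queue_size=4)
  try:
    samples = enqueuer.get()
    inputs, targets = next(samples)
    del inputs, targets  # a real caller would consume many samples here
  finally:
    enqueuer.stop()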