1# Copyright 2019-2023 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15""" 16This file contains basic classes that help users do flexible dataset loading. 17You can define your own dataset loading class, and use GeneratorDataset to help load data. 18After declaring the dataset object, you can further apply dataset operations 19(e.g. filter, skip, concat, map, batch) on it. 20""" 21import builtins 22import copy 23import errno 24import math 25import os 26import signal 27import time 28import multiprocessing 29from multiprocessing.util import Finalize 30import queue 31from functools import partial 32import subprocess 33import threading 34import weakref 35import platform 36import psutil 37import numpy as np 38 39import mindspore._c_dataengine as cde 40 41from mindspore.common import Tensor 42from mindspore import log as logger 43 44from .datasets import UnionBaseDataset, MappableDataset, Schema, to_list, _PythonMultiprocessing, _check_shm_usage 45from . import samplers 46from .queue import _SharedQueue 47from .validators import check_generatordataset, check_numpyslicesdataset, check_paddeddataset 48from ..core.config import get_enable_shared_mem, get_prefetch_size, get_multiprocessing_timeout_interval, \ 49 get_enable_watchdog, get_debug_mode 50from ..core.datatypes import mstypelist_to_detypelist 51from ..core.py_util_helpers import ExceptionHandler 52from ..transforms import transforms 53 54 55def _iter_fn(dataset, num_samples): 56 """ 57 Generator function wrapper for iterable dataset. 58 """ 59 if num_samples is not None and num_samples != 0: 60 ds_iter = iter(dataset) 61 for _ in range(num_samples): 62 try: 63 val = next(ds_iter) 64 except StopIteration: 65 return 66 # convert output tensors to ndarrays 67 yield _convert_row(val) 68 else: 69 for val in dataset: 70 # convert output tensors to ndarrays 71 yield _convert_row(val) 72 73 74def _generator_fn(generator, num_samples): 75 """ 76 Generator function wrapper for generator function dataset. 77 """ 78 if num_samples is not None and num_samples != 0: 79 gen_iter = generator() 80 for _ in range(num_samples): 81 try: 82 val = next(gen_iter) 83 except StopIteration: 84 return 85 yield _convert_row(val) 86 else: 87 gen_iter = generator() 88 for val in gen_iter: 89 yield _convert_row(val) 90 91 92def _cpp_sampler_fn(sample_ids, dataset): 93 """ 94 Generator function wrapper for mappable dataset with cpp sampler. 95 """ 96 if not isinstance(sample_ids, np.ndarray): 97 raise RuntimeError("Sample IDs are not in a numpy array.") 98 if sample_ids.size == 0: 99 raise RuntimeError("Sampler passed an empty sample IDs list.") 100 101 for i in sample_ids: 102 val = dataset[i] 103 # convert output tensors to ndarrays 104 yield _convert_row(val) 105 106 107def _cpp_sampler_fn_mp(sample_ids, sample_fn): 108 """ 109 Multiprocessing generator function wrapper for mappable dataset with cpp sampler. 110 """ 111 if not isinstance(sample_ids, np.ndarray): 112 raise RuntimeError("Sample IDs are not in a numpy array.") 113 if sample_ids.size == 0: 114 raise RuntimeError("Sampler passed an empty sample IDs list.") 115 116 return sample_fn.process(sample_ids) 117 118 119def _fill_worker_indices(workers, indices, idx_cursor, worker_to_quit): 120 """ 121 Worker index queue filler, fill worker index queue in round robin order or QUIT flag. 122 """ 123 num_worker = len(workers) 124 if idx_cursor < len(indices): 125 while idx_cursor < len(indices): 126 try: 127 workers[idx_cursor % num_worker].put(indices[idx_cursor]) 128 idx_cursor += 1 129 except queue.Full: 130 break 131 else: 132 for i in range(num_worker): 133 # just put only one QUIT flag to the sub-thread / sub-process 134 if str(i) not in worker_to_quit: 135 try: 136 workers[i].put("QUIT") 137 worker_to_quit[str(i)] = "QUIT" 138 except queue.Full: 139 continue 140 return idx_cursor, worker_to_quit 141 142 143def _convert_row(row): 144 """ 145 Convert Op return value to numpy, or keep as a dict (if already a dict) 146 """ 147 148 # convert single item to np.array 149 prim_type = (int, float, str, bytes, np.ndarray, Tensor, np.number, np.bool_) 150 if isinstance(row, prim_type): 151 if isinstance(row, Tensor): # mindspore.Tensor 152 item = row.asnumpy() 153 else: 154 item = np.array(row, copy=False) 155 if item.dtype == 'object': 156 raise TypeError("Data type of the input or its converted Numpy array is expected to be " \ 157 "int or float or str, but got {}.".format(item.dtype)) 158 return tuple([item]) 159 160 if isinstance(row, dict): 161 return tuple([row]) 162 163 value = [] 164 # convert each item to np.array 165 idx = 0 166 for x in row: 167 idx += 1 168 if isinstance(x, Tensor): # mindspore.Tensor 169 value.append(x.asnumpy()) 170 elif isinstance(x, dict): 171 value.append(x) 172 else: 173 item = np.array(x, copy=False) 174 if item.dtype == 'object': 175 raise TypeError("Data type of {}th item of the input or its converted Numpy array is expected to be " \ 176 "int or float or str, but got {}.".format(idx, item.dtype)) 177 value.append(item) 178 return tuple(value) 179 180 181class SamplerFn: 182 """ 183 Multiprocessing or multithread generator function wrapper master process. 184 """ 185 186 def __init__(self, dataset, num_worker, multi_process, max_rowsize): 187 self.workers = [] 188 self.dataset = dataset 189 self.num_worker = num_worker 190 self.multi_process = multi_process 191 self.max_rowsize = max_rowsize 192 self.need_join = False 193 self.ppid = os.getpid() 194 self.pids = [] 195 self.check_interval = get_multiprocessing_timeout_interval() # the interval of check queue's size 196 self._final_join = True 197 198 # Event for end of epoch 199 if multi_process is True: 200 try: 201 self.eof = multiprocessing.Event() 202 except Exception: 203 raise RuntimeError("Init multiprocessing.Event() failed, This might be caused by insufficient shm," 204 + " and the recommended shm size is at least 5 GB.") 205 else: 206 self.eof = threading.Event() 207 # Create workers 208 209 # get default queue size and adjust queuesize per worker if there are large # workers 210 queue_size = get_prefetch_size() 211 queue_size = min(queue_size, queue_size * 4 // num_worker) 212 queue_size = max(2, queue_size) 213 214 if multi_process and get_enable_shared_mem(): 215 # generator dataset use idx_queue and res_queue to transfer data between main and subprocess 216 # idx_queue is used multiprocess.Queue which is not shared memory, so it's size is 0. 217 # res_queue is used shared memory, so it' size is max_rowsize which is defined by user. 218 _check_shm_usage(num_worker, queue_size, 0, max_rowsize) 219 self.count = multiprocessing.Value('i', 0) 220 for worker_id in range(num_worker): 221 if multi_process is True: 222 try: 223 worker = _GeneratorWorkerMp(dataset, self.eof, max_rowsize, queue_size, self.ppid, self.count, 224 worker_id) 225 worker.daemon = True 226 # When multi processes fork a subprocess, the lock of the main process is copied to the subprocess, 227 # which may cause deadlock. Therefore, the subprocess startup is performed in the initialization 228 # phase. In this phase, the main process is not locked. 229 worker.start() 230 except OSError as e: 231 if e.errno == errno.EMFILE: 232 raise RuntimeError("Failed to launch multiprocessing of GeneratorDataset: " 233 "Too many open files. Please check if `num_parallel_workers` " 234 "is set too large, or you are creating iterators multiple times. " 235 "You can also increase the limit using `ulimit -n` in the shell " 236 "to avoid this error.") 237 raise 238 except Exception as e: 239 raise RuntimeError("Failed to launch multiprocessing of GeneratorDataset: {0}".format(e)) 240 self.pids.append(worker.pid) 241 self.need_join = True 242 else: 243 worker = _GeneratorWorkerMt(dataset, self.eof, worker_id) 244 worker.daemon = True 245 self.workers.append(worker) 246 self._launch_cleanup_worker(multi_process=multi_process) 247 248 def _interval_log(self, i, start_time, wait_count): 249 cost_time = int(time.time()) - start_time 250 if cost_time / self.check_interval >= wait_count: 251 wait_count += 1 252 self._log_stuck_warning(self.workers[i % self.num_worker], cost_time) 253 return wait_count 254 255 def process(self, indices): 256 """ 257 The main process, start the child process or child thread, and fill the index queue. 258 Get the result and return. 259 """ 260 for w in self.workers: 261 # Check whether the queue of the subprocess is empty. 262 if not w.queue_empty(): 263 # in failover reset scenario the QUIT flag should be pop first 264 while w.idx_queue.qsize() > 0: 265 try: 266 result = w.idx_queue.get(timeout=1) 267 if result != "QUIT": 268 raise Exception("The queue of the subprocess is not empty.") 269 except queue.Empty: 270 continue 271 # Start all workers 272 if not w.is_alive(): 273 w.start() 274 275 # Fill initial index queues 276 idx_cursor = 0 277 # worker to quit 278 worker_to_quit = {} 279 idx_cursor, worker_to_quit = _fill_worker_indices(self.workers, indices, idx_cursor, worker_to_quit) 280 281 # Fetch results 282 for i in range(len(indices)): 283 if self.eof.is_set(): 284 self._stop_subprocess() 285 return 286 if self.multi_process is True and not psutil.pid_exists(self.workers[i % self.num_worker].pid): 287 self._stop_subprocess() 288 return 289 # Fetch result and put index 290 try: 291 # To avoid get timeout from queue, check the res_queue size. 292 start_time = int(time.time()) 293 wait_count = 1 294 while self.workers[i % self.num_worker].res_queue.empty(): 295 if self.eof.is_set(): 296 logger.warning("Generator receives a termination signal, stop waiting for data " 297 "from subprocess.") 298 self._stop_subprocess() 299 return 300 time.sleep(0.1) 301 wait_count = self._interval_log(i, start_time, wait_count) 302 result = self.workers[i % self.num_worker].get() 303 # Because there is no need to copy when creating Tensors in the C++layer, it reduces the time 304 # from np.ndarray to C++Tensor creation. However, when using shared memory in multiple processes, 305 # the address of the shared memory will always be passed to subsequent nodes in the dataset pipeline, 306 # and the shared memory will also be written by the current node, causing dirty data to be accessed 307 # by subsequent nodes in the pipeline. So make a memory copy here to solve the problem of 308 # shared memory being contaminated. 309 if self.multi_process is True and get_enable_shared_mem(): 310 result = copy.deepcopy(result) 311 if isinstance(result, ExceptionHandler): 312 result.reraise() 313 except queue.Empty: 314 self._stop_subprocess() 315 raise Exception("Generator worker process timeout.") 316 except KeyboardInterrupt: 317 self._stop_subprocess() 318 raise Exception("Generator worker receives KeyboardInterrupt.") 319 if self.eof.is_set(): 320 self._stop_subprocess() 321 return 322 323 idx_cursor, worker_to_quit = _fill_worker_indices(self.workers, indices, idx_cursor, worker_to_quit) 324 325 yield _convert_row(result) 326 327 def _log_stuck_warning(self, worker, waiting_time): 328 """ 329 Log warning of the stuck worker, containing the worker ID, waiting time and 330 the current stack (if py-spy installed). 331 332 Args: 333 worker (Union[threading.Thread, multiprocessing.Process]): The worker instance. 334 waiting_time (int): The waiting time for getting data from the worker. 335 """ 336 if self.multi_process: 337 stuck_worker_id = worker.pid 338 worker_type = "process" 339 stuck_pid = stuck_worker_id 340 else: 341 if hasattr(worker, "native_id"): 342 # only supported since Python 3.8 343 stuck_worker_id = worker.native_id 344 else: 345 stuck_worker_id = worker.ident 346 worker_type = "thread" 347 stuck_pid = os.getpid() # get the process ID of the stuck thread 348 warning_message = "Has been waiting for data from Generator worker {0} ID '{1}' " \ 349 "for more than {2} seconds. Please check if the user defined " \ 350 "dataset of GeneratorDataset has a dead loop, or is processing " \ 351 "too slowly. ".format(worker_type, stuck_worker_id, waiting_time) 352 install_status, _ = subprocess.getstatusoutput("py-spy --version") 353 if install_status == 0: 354 stack = subprocess.getoutput("py-spy dump -p {}".format(stuck_pid)) 355 warning_message += "Below is the stack of this worker:\n{0}\n".format(stack) 356 else: 357 warning_message += "You can install py-spy via `pip install py-spy`, then " \ 358 "stop and rerun your script to get the current stack. " 359 warning_message += "If it is not a problem, you can adjust the printing frequency of this log via " \ 360 "the `mindspore.dataset.config.set_multiprocessing_timeout_interval` interface." 361 logger.warning(warning_message) 362 363 def _launch_cleanup_worker(self, multi_process): 364 """ 365 We need a extra thread and process if main process or subprocess was killed. 366 367 Args: 368 multi_process: Whether use multiprocess. 369 """ 370 if multi_process is True and platform.system().lower() != 'windows': 371 _clean_worker_func = _PythonMultiprocessing._clean_process # pylint: disable=W0212 372 self.cleaning_process = multiprocessing.Process(target=_clean_worker_func, 373 name="GeneratorCleanProcess", 374 args=(self.ppid, self.workers, self.eof)) 375 self.cleaning_process.daemon = True 376 self.cleaning_process.start() 377 378 if get_enable_watchdog(): 379 self.eot = threading.Event() 380 self.watch_dog = threading.Thread(target=_PythonMultiprocessing._watch_dog, # pylint: disable=W0212 381 name="GeneratorWatchDog", 382 args=(self.eot, self.workers + [self.cleaning_process])) 383 self.watch_dog.daemon = True 384 self.watch_dog.start() 385 386 if self._final_join is True: 387 self._jointhread = Finalize( 388 self.watch_dog, self._finalize_join, 389 args=(weakref.ref(self.watch_dog), self.eot), 390 exitpriority=-5 391 ) 392 393 def _stop_subprocess(self): 394 """Only the main process can call join.""" 395 if self.need_join is True and self.ppid == os.getpid(): 396 if hasattr(self, 'eof') and self.eof is not None: 397 self.eof.set() 398 # close the watch dog first 399 self._abort_watchdog() 400 self.need_join = False 401 for w in self.workers: 402 if self.multi_process is True and hasattr(w, '_closed') and w._closed is False: # pylint: disable=W0212 403 try: 404 # del the queue first 405 del w.res_queue 406 del w.idx_queue 407 408 # let the quit event notify the worker process to exit 409 w.join(timeout=5) 410 if w.is_alive(): 411 # if the worker process did not exit, it may hang, try to terminate it 412 w.terminate() 413 w.close() 414 except Exception: # pylint: disable=W0703 415 # Block all errors when join 416 continue 417 418 # release the file descriptor handle 419 check_interval = get_multiprocessing_timeout_interval() 420 for w in self.workers: 421 try: 422 subprocess_file_descriptor = w.sentinel 423 st = time.time() 424 while _PythonMultiprocessing.is_process_alive(w.pid): 425 time.sleep(0.01) # sleep 10ms, waiting for the subprocess exit 426 if time.time() - st > check_interval: 427 logger.warning("Waiting for the subprocess worker [{}] to exit.".format(w.pid)) 428 st += check_interval 429 except ValueError as e: 430 if "process object is closed" in str(e): 431 continue 432 raise e 433 try: 434 if w.is_alive(): 435 os.close(subprocess_file_descriptor) 436 except OSError as e: 437 # Maybe the file descriptor had been released, so ignore the 'Bad file descriptor' 438 if "Bad file descriptor" not in str(e): 439 raise e 440 441 self.workers.clear() 442 self.workers = None 443 444 def _abort_watchdog(self): 445 """Let watchdog quit.""" 446 if hasattr(self, 'eot') and self.eot is not None and not self.eot.is_set(): 447 self.eot.set() 448 if hasattr(self, 'cleaning_process') and self.cleaning_process is not None: 449 # let the quit event notify the cleaning process to exit 450 self.cleaning_process.join(timeout=5) 451 if self.cleaning_process.is_alive(): 452 # if the cleaning process did not exit, it may hang, try to terminate it 453 _PythonMultiprocessing._terminate_processes([self.cleaning_process]) # pylint: disable=W0212 454 del self.cleaning_process 455 if hasattr(self, 'count'): 456 del self.count 457 458 @classmethod 459 def _finalize_join(cls, twr, eot): 460 thread = twr() 461 if thread is not None: 462 if eot is not None and not eot.is_set(): 463 eot.set() 464 thread.join() 465 466 def __del__(self): 467 try: 468 self._stop_subprocess() 469 except TypeError: 470 pass 471 472 def __deepcopy__(self, memodict, exclude=()): 473 self.__init__(self.dataset, self.num_worker, self.multi_process, self.max_rowsize) 474 475 476def _subprocess_handle(eof, signum, frame): 477 threading.Thread(target=eof.set()).start() 478 479 480def _ignore_sigint(is_multiprocessing): 481 """ 482 We need to ignore sigint signal here so subprocesses can exit normally and clear. 483 """ 484 if is_multiprocessing: 485 signal.signal(signal.SIGINT, signal.SIG_IGN) 486 487 488def _main_process_already_exit(eof, is_multiprocessing, idx_queue, result_queue, ppid): 489 """ 490 Judge whether main process already exit. 491 """ 492 if eof.is_set() or (is_multiprocessing and platform.system().lower() != 'windows' and 493 not _PythonMultiprocessing.is_process_alive(ppid)): 494 if is_multiprocessing: 495 idx_queue.cancel_join_thread() 496 result_queue.cancel_join_thread() 497 return True 498 return False 499 500 501def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing, ppid=-1): 502 """ 503 Multithread or multiprocess generator worker process loop. 504 """ 505 if is_multiprocessing: 506 result_queue.cancel_join_thread() # Ensure that the process does not hung when exiting 507 signal.signal(signal.SIGTERM, partial(_subprocess_handle, eof)) 508 while not eof.is_set(): 509 _ignore_sigint(is_multiprocessing=is_multiprocessing) 510 511 # Fetch index, block 512 try: 513 idx = idx_queue.get(timeout=1) 514 except queue.Empty: 515 if _main_process_already_exit(eof, is_multiprocessing, idx_queue, result_queue, ppid) is True: 516 del idx_queue 517 del result_queue 518 return 519 # If end-of-file (eof) is not set, continue to get data from idx_queue 520 continue 521 if idx == "QUIT": 522 # all the data had been processed, so we release the executor which is used by the current thread/process 523 transforms.clean_unused_executors() 524 continue 525 if idx is None: 526 # When the queue is out of scope from master process, a None item can be fetched from the queue. 527 # Upon receiving None, worker process should check if eof is set. 528 if not eof.is_set(): 529 raise Exception("") 530 del idx_queue 531 del result_queue 532 return 533 if eof.is_set(): 534 del idx_queue 535 del result_queue 536 return 537 # Fetch data, any exception from __getitem__ will terminate worker and timeout master process 538 try: 539 result = dataset[idx] 540 except Exception: # pylint: disable=broad-except 541 result = ExceptionHandler(where="in GeneratorDataset worker process") 542 # Send data, block 543 while not eof.is_set(): 544 try: 545 result_queue.put(result, timeout=5) 546 except queue.Full: 547 if _main_process_already_exit(eof, is_multiprocessing, idx_queue, result_queue, ppid) is True: 548 del idx_queue 549 del result_queue 550 return 551 # If eof is not set, continue to put data to result_queue 552 continue 553 break 554 del result, idx 555 556 557class _GeneratorWorkerMt(threading.Thread): 558 """ 559 Worker process for multi-thread Generator. 560 """ 561 562 def __init__(self, dataset, eof, worker_id): 563 self.idx_queue = queue.Queue(16) 564 self.res_queue = queue.Queue(16) 565 super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, False), 566 name="GeneratorWorkerThread" + str(worker_id)) 567 568 def put(self, item): 569 """ 570 Put function for worker index queue. Never block. Raise queue.Full on failure. 571 """ 572 self.idx_queue.put_nowait(item) 573 574 def get(self): 575 """ 576 Get function for worker result queue. Block with timeout. 577 """ 578 return self.res_queue.get(timeout=30) 579 580 def queue_empty(self): 581 if not self.idx_queue.empty(): 582 logger.warning("idx_queue is not empty") 583 return False 584 if not self.res_queue.empty(): 585 logger.warning("res_queue is not empty") 586 return False 587 return True 588 589 590class _GeneratorWorkerMp(multiprocessing.Process): 591 """ 592 Worker process for multiprocess Generator. 593 """ 594 595 def __init__(self, dataset, eof, max_rowsize, queue_size, ppid, count, worker_id): 596 self.idx_queue = multiprocessing.Queue(queue_size) 597 if get_enable_shared_mem(): 598 self.res_queue = _SharedQueue(queue_size, count, max_rowsize=max_rowsize) 599 else: 600 self.res_queue = multiprocessing.Queue(queue_size) 601 self.idx_queue.cancel_join_thread() # Ensure that the process does not hung when exiting 602 super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True, ppid), 603 name="GeneratorWorkerProcess" + str(worker_id)) 604 605 def put(self, item): 606 """ 607 Put function for worker index queue. Never block. Raise queue.Full on failure. 608 """ 609 self.idx_queue.put_nowait(item) 610 611 def get(self): 612 """ 613 Get function for worker result queue. Block with timeout. 614 """ 615 # Relax 10s to 30s, since it sometimes will cause "Generator worker process timeout" 616 # when we run too many iterators with infinite epoch(num_epoch=-1) 617 return self.res_queue.get(timeout=30) 618 619 def queue_empty(self): 620 if not self.idx_queue.empty(): 621 logger.warning("idx_queue is not empty.") 622 return False 623 if not self.res_queue.empty(): 624 logger.warning("res_queue is not empty.") 625 return False 626 return True 627 628 def __del__(self): 629 # del all the Queue & SharedQueue when the iter had been deleted from ITERATORS_LIST 630 if hasattr(self, 'idx_queue'): 631 del self.idx_queue 632 if hasattr(self, 'res_queue'): 633 # del the queue when has 634 del self.res_queue 635 636 637class GeneratorDataset(MappableDataset, UnionBaseDataset): 638 """ 639 A source dataset that generates data from Python by invoking Python data source each epoch. 640 641 The column names and column types of generated dataset depend on Python data defined by users. 642 643 Args: 644 source (Union[Callable, Iterable, Random Accessible]): 645 A generator callable object, an iterable Python object or a random accessible Python object. 646 Callable source is required to return a tuple of NumPy arrays as a row of the dataset on source().next(). 647 Iterable source is required to return a tuple of NumPy arrays as a row of the dataset on 648 iter(source).next(). 649 Random accessible source is required to return a tuple of NumPy arrays as a row of the dataset on 650 source[idx]. 651 column_names (Union[str, list[str]], optional): List of column names of the dataset. Default: ``None`` . 652 Users are required to provide either column_names or schema. 653 column_types (list[mindspore.dtype], optional): List of column data types of the dataset. Default: ``None`` . 654 If provided, sanity check will be performed on generator output. 655 schema (Union[str, Schema], optional): Data format policy, which specifies the data types and shapes of the data 656 column to be read. Both JSON file path and objects constructed by :class:`mindspore.dataset.Schema` are 657 acceptable. Default: ``None`` . 658 num_samples (int, optional): The number of samples to be included in the dataset. 659 Default: ``None`` , all images. 660 num_parallel_workers (int, optional): Number of worker threads/subprocesses used to 661 fetch the dataset in parallel. Default: ``1``. 662 shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required. 663 Default: ``None`` , expected order behavior shown in the table below. 664 sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. Random accessible 665 input is required. Default: ``None`` , expected order behavior shown in the table below. 666 num_shards (int, optional): Number of shards that the dataset will be divided into. Default: ``None`` . 667 Random accessible input is required. When this argument is specified, `num_samples` reflects the maximum 668 sample number of per shard. 669 shard_id (int, optional): The shard ID within `num_shards` . Default: ``None`` . 670 This argument must be specified only when `num_shards` is also specified. 671 Random accessible input is required. 672 python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This 673 option could be beneficial if the Python operation is computational heavy. Default: ``True``. 674 max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory 675 allocation to copy data between processes, the total occupied shared memory will increase as 676 ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set to -1, 677 shared memory will be dynamically allocated with the actual size of data. This is only used if 678 ``python_multiprocessing`` is set to True. Default: 16. 679 680 Raises: 681 RuntimeError: If source raises an exception during execution. 682 RuntimeError: If len of column_names does not match output len of source. 683 ValueError: If `num_parallel_workers` exceeds the max thread numbers. 684 ValueError: If sampler and shuffle are specified at the same time. 685 ValueError: If sampler and sharding are specified at the same time. 686 ValueError: If `num_shards` is specified but shard_id is None. 687 ValueError: If shard_id is specified but `num_shards` is None. 688 ValueError: If `shard_id` is not in range of [0, `num_shards` ). 689 690 Tutorial Examples: 691 - `Load & Process Data With Dataset Pipeline 692 <https://www.mindspore.cn/docs/en/master/api_python/samples/dataset/dataset_gallery.html>`_ 693 694 Note: 695 - If you configure `python_multiprocessing=True` (Default: ``True`` ) and `num_parallel_workers>1` 696 (default: ``1`` ) indicates that the multi-process mode is started for data load acceleration. 697 At this time, as the datasetiterates, the memory consumption of the subprocess will gradually increase, 698 mainly because the subprocess of the user-defined dataset obtains the member variables from the main 699 process in the Copy On Write way. 700 Example: If you define a dataset with `__ init__` function which contains a large number of member variable 701 data (for example, a very large file name list is loaded during the dataset construction) and uses the 702 multi-process mode, which may cause the problem of OOM (the estimated total memory usage is: 703 `(num_parallel_workers+1) * size of the parent process` ). The simplest solution is to replace Python objects 704 (such as list/dict/int/float/string) with non referenced data types 705 (such as Pandas, Numpy or PyArrow objects) for member variables, or load less meta data in member variables, 706 or configure `python_multiprocessing=False` to use multi-threading mode. 707 708 There are several classes/functions that can help you reduce the size of member variables, and you can choose 709 to use them: 710 711 1. :class:`mindspore.dataset.utils.LineReader`: Use this class to initialize your text file object in the 712 `__init__` function. Then read the file content based on the line number of the object with the `__getitem__` 713 function. 714 715 - Input `source` accepts user-defined Python functions (PyFuncs), Do not add network computing operators from 716 mindspore.nn and mindspore.ops or others into this `source` . 717 - The parameters `num_samples` , `shuffle` , `num_shards` , `shard_id` can be used to control the sampler 718 used in the dataset, and their effects when combined with parameter `sampler` are as follows. 719 720 .. include:: mindspore.dataset.sampler.txt 721 722 Examples: 723 >>> import mindspore.dataset as ds 724 >>> import numpy as np 725 >>> 726 >>> # 1) Multidimensional generator function as callable input. 727 >>> def generator_multidimensional(): 728 ... for i in range(64): 729 ... yield (np.array([[i, i + 1], [i + 2, i + 3]]),) 730 >>> 731 >>> dataset = ds.GeneratorDataset(source=generator_multidimensional, column_names=["multi_dimensional_data"]) 732 >>> 733 >>> # 2) Multi-column generator function as callable input. 734 >>> def generator_multi_column(): 735 ... for i in range(64): 736 ... yield np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]) 737 >>> 738 >>> dataset = ds.GeneratorDataset(source=generator_multi_column, column_names=["col1", "col2"]) 739 >>> 740 >>> # 3) Iterable dataset as iterable input. 741 >>> class MyIterable: 742 ... def __init__(self): 743 ... self._index = 0 744 ... self._data = np.random.sample((5, 2)) 745 ... self._label = np.random.sample((5, 1)) 746 ... 747 ... def __next__(self): 748 ... if self._index >= len(self._data): 749 ... raise StopIteration 750 ... else: 751 ... item = (self._data[self._index], self._label[self._index]) 752 ... self._index += 1 753 ... return item 754 ... 755 ... def __iter__(self): 756 ... self._index = 0 757 ... return self 758 ... 759 ... def __len__(self): 760 ... return len(self._data) 761 >>> 762 >>> dataset = ds.GeneratorDataset(source=MyIterable(), column_names=["data", "label"]) 763 >>> 764 >>> # 4) Random accessible dataset as random accessible input. 765 >>> class MyAccessible: 766 ... def __init__(self): 767 ... self._data = np.random.sample((5, 2)) 768 ... self._label = np.random.sample((5, 1)) 769 ... 770 ... def __getitem__(self, index): 771 ... return self._data[index], self._label[index] 772 ... 773 ... def __len__(self): 774 ... return len(self._data) 775 >>> 776 >>> dataset = ds.GeneratorDataset(source=MyAccessible(), column_names=["data", "label"]) 777 >>> 778 >>> # list, dict, tuple of Python is also random accessible 779 >>> dataset = ds.GeneratorDataset(source=[(np.array(0),), (np.array(1),), (np.array(2),)], column_names=["col"]) 780 """ 781 782 @check_generatordataset 783 def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None, 784 num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None, 785 python_multiprocessing=True, max_rowsize=6): 786 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 787 shuffle=shuffle, num_shards=num_shards, shard_id=shard_id) 788 if isinstance(source, builtins.zip): 789 # Although zip is iterable, it does not have the feature of repeated iteration, so pass it to the array. 790 self.source = [item for item in source] 791 else: 792 self.source = source 793 self.prepared_source = None # source to be sent to C++ 794 if hasattr(self, 'operator_mixed') and getattr(self, 'operator_mixed') is True: 795 self.num_parallel_workers = 1 796 logger.warning( 797 "Input 'source' of 'GeneratorDataset' includes network computing operators like in mindspore.nn, " 798 "mindspore.ops, mindspore.numpy module and etc, which do not support multi-thread compiling, recommend" 799 " to replace it with python implemented operator like numpy etc. Here decrease 'num_parallel_workers' " 800 "into 1.") 801 802 if platform.system().lower() == 'windows' and num_parallel_workers > 1 and python_multiprocessing: 803 logger.warning("Python multiprocessing is not supported on Windows platform.") 804 self.python_multiprocessing = python_multiprocessing if platform.system().lower() != 'windows' else False 805 if self.python_multiprocessing and get_debug_mode(): 806 logger.warning("Python multiprocessing is not supported in debug mode." 807 " Ignoring Python multiprocessing for GeneratorDataset.") 808 self.python_multiprocessing = False 809 810 self.column_names = to_list(column_names) 811 812 if column_types is not None: 813 self.column_types = mstypelist_to_detypelist(column_types) 814 else: 815 self.column_types = [] 816 817 self.schema = schema 818 if schema is not None: 819 self.schema = schema 820 if not isinstance(schema, Schema): 821 self.schema = Schema(schema) 822 # Move get dataset_size by len from parse to here, because self.source will 823 # lose attribution of '__len__' after deepcopy. 824 self.source_len = -1 # unknown 825 if hasattr(self.source, "__len__"): 826 self.source_len = len(self.source) 827 828 # if user defined sampler, update the self.source_len 829 if isinstance(self.sampler, samplers.Sampler) or hasattr(self.sampler, "__iter__"): 830 self.source_len = len(list(sampler)) 831 832 self.max_rowsize = max_rowsize 833 self.sample_fn = None 834 835 def __deepcopy__(self, memodict): 836 if id(self) in memodict: 837 return memodict[id(self)] 838 return self.__safe_deepcopy__(memodict, exclude=("source", "__transfer_dataset__")) 839 840 def is_shuffled(self): 841 if self.sampler: 842 return self.sampler.is_shuffled() 843 return False 844 845 def is_sharded(self): 846 if self.sampler: 847 return self.sampler.is_sharded() 848 return False 849 850 def split(self, sizes, randomize=True): 851 if hasattr(self.source, "__getitem__"): 852 # If the source has __getitem__ attribute, call the split method of MappableDataset. 853 # Otherwise, call the split method of Dataset. 854 return super().split(sizes, randomize) 855 return super(MappableDataset, self).split(sizes, randomize) 856 857 def prepare_multiprocessing(self): 858 """Preprocessing of prepared_source.""" 859 sample_fn = None 860 if self.sampler is not None and hasattr(self.source, "__getitem__"): 861 # The reason why there is a try catch here is because when the new op is being constructed with shared 862 # memory enabled, there will be an exception thrown if there is not enough shared memory available 863 if self.source_len == -1: 864 raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!") 865 866 if self.num_parallel_workers > 1: 867 self.__validate_memory_usage() 868 869 sample_fn = SamplerFn(self.source, self.num_parallel_workers, self.python_multiprocessing, 870 self.max_rowsize) 871 self.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn)) 872 else: 873 self.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source)) 874 self.sample_fn = sample_fn 875 else: 876 self.sampler = None 877 self.sample_fn = sample_fn 878 self.source_len = min(self.source_len, self.num_samples) if self.num_samples != 0 else self.source_len 879 if not hasattr(self.source, "__iter__"): 880 # Use generator function if input callable 881 self.prepared_source = (lambda: _generator_fn(self.source, self.num_samples)) 882 else: 883 # Use iterator function if input is iterable 884 # Random accessible input is also iterable 885 self.prepared_source = (lambda: _iter_fn(self.source, self.num_samples)) 886 887 def parse(self, children=None): 888 self.prepare_multiprocessing() 889 if self.schema is None: 890 return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len, 891 self.sampler, self.num_parallel_workers) 892 schema = self.schema 893 if isinstance(schema, Schema): 894 schema = self.schema.cpp_schema 895 return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler, 896 self.num_parallel_workers) 897 898 def __validate_memory_usage(self): 899 """ 900 Check memory usage when mulit-processing mode, when 85% prompt warning and 100% raise error. 901 """ 902 if self.python_multiprocessing: 903 # if use num_parallel_workers is to large when python_multiprocessing=True which would cause 904 # OOM error get the num_shards 905 valid_num_shards = 1 906 if isinstance(self.sampler, samplers.DistributedSampler): 907 valid_num_shards = self.sampler.num_shards 908 elif self.num_shards is not None: 909 valid_num_shards = self.num_shards 910 911 # get process memory usage 912 process = psutil.Process(os.getpid()) 913 process_memory = process.memory_info().rss 914 sys_memory_available = psutil.virtual_memory().available 915 916 total_memory_maybe_used = process_memory * self.num_parallel_workers * valid_num_shards 917 if total_memory_maybe_used / sys_memory_available > 0.85: 918 valid_num_worker = math.floor(sys_memory_available * 0.85 / valid_num_shards / process_memory) 919 valid_num_worker = 1 if valid_num_worker <= 0 else valid_num_worker 920 info = "GeneratorDataset's num_parallel_workers: {} is too large which may cause a lot of memory " \ 921 "occupation (>85%) or out of memory(OOM) during multiprocessing. Therefore, it is recommended " \ 922 "to reduce num_parallel_workers to {} or smaller.".format(self.num_parallel_workers, 923 valid_num_worker) 924 logger.warning(info) 925 926 927class _NumpySlicesDataset: 928 """ 929 Mainly for dealing with several kinds of formats of Python data, and return one row each time. 930 """ 931 932 def __init__(self, data, column_list=None): 933 self.column_list = None 934 # Convert dict data into tuple 935 if isinstance(data, dict): 936 data = self.process_dict(data) 937 938 if isinstance(data, tuple): 939 self.data = data 940 else: 941 self.data = (data,) 942 943 # check whether the data length in each column is equal 944 data_len = [len(data_item) for data_item in self.data] 945 if data_len[1:] != data_len[:-1]: 946 raise ValueError("Data length in each column is not equal.") 947 948 # Init column_name 949 if column_list is not None: 950 self.column_list = column_list 951 elif self.column_list is None: 952 self.column_list = [] 953 column_num = len(self.data) 954 for i in range(column_num): 955 self.column_list.append("column_" + str(i)) 956 957 def __getitem__(self, index): 958 data_row = [d[index] for d in self.data] 959 data_res = tuple(data_row) 960 return data_res 961 962 def __len__(self): 963 return len(self.data[0]) 964 965 def process_dict(self, input_data): 966 """ 967 Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first. 968 """ 969 # Convert pandas like dict(has "values" column) into General dict 970 data_keys = list(input_data.keys()) 971 data_col = input_data[data_keys[0]] 972 if hasattr(data_col, "values"): 973 new_dict = {} 974 for key in data_keys: 975 item1 = input_data.pop(key) 976 new_dict[key] = item1.values 977 input_data = new_dict 978 979 # Convert the data in dict into tuple 980 data = () 981 keys = list(input_data.keys()) 982 self.column_list = keys 983 for key in keys: 984 value = input_data[key] 985 data = data + (list(value),) 986 987 return data 988 989 990class NumpySlicesDataset(GeneratorDataset): 991 """ 992 Creates a dataset with given data slices, mainly for loading Python data into dataset. 993 994 The column names and column types of generated dataset depend on Python data defined by users. 995 996 Args: 997 data (Union[list, tuple, dict]) Input of given data. Supported data types include: list, tuple, dict and other 998 NumPy formats. Input data will be sliced along the first dimension and generate additional rows, if input is 999 list, there will be one column in each row, otherwise there tends to be multi columns. Large data is not 1000 recommended to be loaded in this way as data is loading into memory. 1001 column_names (list[str], optional): List of column names of the dataset. Default: ``None`` . If `column_names` 1002 is not provided, the output column names will be named as the keys of dict when the input data is a dict, 1003 otherwise they will be named like column_0, column_1 ... 1004 num_samples (int, optional): The number of samples to be included in the dataset. Default: ``None`` , 1005 all samples. 1006 num_parallel_workers (int, optional): Number of worker subprocesses used to 1007 fetch the dataset in parallel. Default: ``1``. 1008 shuffle (bool, optional): Whether or not to perform shuffle on the dataset. 1009 Default: ``None`` , expected order behavior shown in the table below. 1010 sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. 1011 Default: ``None`` , expected order behavior shown in the table below. 1012 num_shards (int, optional): Number of shards that the dataset will be divided into. Default: ``None`` . 1013 When this argument is specified, `num_samples` reflects the max sample number of per shard. 1014 shard_id (int, optional): The shard ID within `num_shards` . Default: ``None`` . This argument must be 1015 specified only when `num_shards` is also specified. 1016 1017 Note: 1018 - The parameters `num_samples` , `shuffle` , `num_shards` , `shard_id` can be used to control the sampler 1019 used in the dataset, and their effects when combined with parameter `sampler` are as follows. 1020 1021 .. include:: mindspore.dataset.sampler.txt 1022 1023 Raises: 1024 RuntimeError: If len of column_names does not match output len of data. 1025 ValueError: If `num_parallel_workers` exceeds the max thread numbers. 1026 ValueError: If sampler and shuffle are specified at the same time. 1027 ValueError: If sampler and sharding are specified at the same time. 1028 ValueError: If `num_shards` is specified but shard_id is None. 1029 ValueError: If shard_id is specified but `num_shards` is None. 1030 ValueError: If `shard_id` is not in range of [0, `num_shards` ). 1031 1032 Tutorial Examples: 1033 - `Load & Process Data With Dataset Pipeline 1034 <https://www.mindspore.cn/docs/en/master/api_python/samples/dataset/dataset_gallery.html>`_ 1035 1036 Examples: 1037 >>> import mindspore.dataset as ds 1038 >>> # 1) Input data can be a list 1039 >>> data = [1, 2, 3] 1040 >>> dataset = ds.NumpySlicesDataset(data=data, column_names=["column_1"]) 1041 >>> 1042 >>> # 2) Input data can be a dictionary, and column_names will be its keys 1043 >>> data = {"a": [1, 2], "b": [3, 4]} 1044 >>> dataset = ds.NumpySlicesDataset(data=data) 1045 >>> 1046 >>> # 3) Input data can be a tuple of lists (or NumPy arrays), each tuple element refers to data in each column 1047 >>> data = ([1, 2], [3, 4], [5, 6]) 1048 >>> dataset = ds.NumpySlicesDataset(data=data, column_names=["column_1", "column_2", "column_3"]) 1049 >>> 1050 >>> # 4) Load data from CSV file 1051 >>> import pandas as pd 1052 >>> df = pd.read_csv(filepath_or_buffer=csv_dataset_dir[0]) 1053 >>> dataset = ds.NumpySlicesDataset(data=dict(df), shuffle=False) 1054 """ 1055 1056 @check_numpyslicesdataset 1057 def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, 1058 num_shards=None, shard_id=None): 1059 dataset = _NumpySlicesDataset(data, column_names) 1060 super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples, 1061 num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, 1062 num_shards=num_shards, shard_id=shard_id) 1063 1064 1065class _PaddedDataset: 1066 """ 1067 Mainly for combining false samples provided by users into a dataset. 1068 1069 Args: 1070 padded_samples (list(dict)): Data provided by user to be added to the initial Dataset. 1071 """ 1072 1073 def __init__(self, padded_samples): 1074 self.column_names = list(padded_samples[0].keys()) 1075 self.padded_samples = padded_samples 1076 1077 def __getitem__(self, item): 1078 return (self.padded_samples[item][key] for key in self.column_names) 1079 1080 def __len__(self): 1081 return len(self.padded_samples) 1082 1083 1084class PaddedDataset(GeneratorDataset): 1085 """ 1086 Creates a dataset with filler data provided by user. 1087 1088 Mainly used to add to the original dataset and assign it to the corresponding shard. 1089 1090 Args: 1091 padded_samples (list(dict)): Samples provided by user. 1092 1093 Raises: 1094 TypeError: If padded_samples is not an instance of list. 1095 TypeError: If the element of padded_samples is not an instance of dict. 1096 ValueError: If the padded_samples is empty. 1097 1098 Tutorial Examples: 1099 - `Load & Process Data With Dataset Pipeline 1100 <https://www.mindspore.cn/docs/en/master/api_python/samples/dataset/dataset_gallery.html>`_ 1101 1102 Examples: 1103 >>> import mindspore.dataset as ds 1104 >>> import numpy as np 1105 >>> data = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)}] 1106 >>> dataset = ds.PaddedDataset(padded_samples=data) 1107 """ 1108 1109 @check_paddeddataset 1110 def __init__(self, padded_samples): 1111 dataset = _PaddedDataset(padded_samples) 1112 super().__init__(dataset, column_names=dataset.column_names, num_shards=None, shard_id=None, shuffle=False) 1113 self._dataset_size = len(dataset.padded_samples) 1114 self.padded_samples = padded_samples 1115