• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019-2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16This file contains basic classes that help users do flexible dataset loading.
17You can define your own dataset loading class, and use GeneratorDataset to help load data.
18After declaring the dataset object, you can further apply dataset operations
19(e.g. filter, skip, concat, map, batch) on it.
20"""
21import builtins
22import copy
23import errno
24import math
25import os
26import signal
27import time
28import multiprocessing
29from multiprocessing.util import Finalize
30import queue
31from functools import partial
32import subprocess
33import threading
34import weakref
35import platform
36import psutil
37import numpy as np
38
39import mindspore._c_dataengine as cde
40
41from mindspore.common import Tensor
42from mindspore import log as logger
43
44from .datasets import UnionBaseDataset, MappableDataset, Schema, to_list, _PythonMultiprocessing, _check_shm_usage
45from . import samplers
46from .queue import _SharedQueue
47from .validators import check_generatordataset, check_numpyslicesdataset, check_paddeddataset
48from ..core.config import get_enable_shared_mem, get_prefetch_size, get_multiprocessing_timeout_interval, \
49    get_enable_watchdog, get_debug_mode
50from ..core.datatypes import mstypelist_to_detypelist
51from ..core.py_util_helpers import ExceptionHandler
52from ..transforms import transforms
53
54
55def _iter_fn(dataset, num_samples):
56    """
57    Generator function wrapper for iterable dataset.
58    """
59    if num_samples is not None and num_samples != 0:
60        ds_iter = iter(dataset)
61        for _ in range(num_samples):
62            try:
63                val = next(ds_iter)
64            except StopIteration:
65                return
66            # convert output tensors to ndarrays
67            yield _convert_row(val)
68    else:
69        for val in dataset:
70            # convert output tensors to ndarrays
71            yield _convert_row(val)
72
73
74def _generator_fn(generator, num_samples):
75    """
76    Generator function wrapper for generator function dataset.
77    """
78    if num_samples is not None and num_samples != 0:
79        gen_iter = generator()
80        for _ in range(num_samples):
81            try:
82                val = next(gen_iter)
83            except StopIteration:
84                return
85            yield _convert_row(val)
86    else:
87        gen_iter = generator()
88        for val in gen_iter:
89            yield _convert_row(val)
90
91
92def _cpp_sampler_fn(sample_ids, dataset):
93    """
94    Generator function wrapper for mappable dataset with cpp sampler.
95    """
96    if not isinstance(sample_ids, np.ndarray):
97        raise RuntimeError("Sample IDs are not in a numpy array.")
98    if sample_ids.size == 0:
99        raise RuntimeError("Sampler passed an empty sample IDs list.")
100
101    for i in sample_ids:
102        val = dataset[i]
103        # convert output tensors to ndarrays
104        yield _convert_row(val)
105
106
107def _cpp_sampler_fn_mp(sample_ids, sample_fn):
108    """
109    Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
110    """
111    if not isinstance(sample_ids, np.ndarray):
112        raise RuntimeError("Sample IDs are not in a numpy array.")
113    if sample_ids.size == 0:
114        raise RuntimeError("Sampler passed an empty sample IDs list.")
115
116    return sample_fn.process(sample_ids)
117
118
119def _fill_worker_indices(workers, indices, idx_cursor, worker_to_quit):
120    """
121    Worker index queue filler, fill worker index queue in round robin order or QUIT flag.
122    """
123    num_worker = len(workers)
124    if idx_cursor < len(indices):
125        while idx_cursor < len(indices):
126            try:
127                workers[idx_cursor % num_worker].put(indices[idx_cursor])
128                idx_cursor += 1
129            except queue.Full:
130                break
131    else:
132        for i in range(num_worker):
133            # just put only one QUIT flag to the sub-thread / sub-process
134            if str(i) not in worker_to_quit:
135                try:
136                    workers[i].put("QUIT")
137                    worker_to_quit[str(i)] = "QUIT"
138                except queue.Full:
139                    continue
140    return idx_cursor, worker_to_quit
141
142
143def _convert_row(row):
144    """
145    Convert Op return value to numpy, or keep as a dict (if already a dict)
146    """
147
148    # convert single item to np.array
149    prim_type = (int, float, str, bytes, np.ndarray, Tensor, np.number, np.bool_)
150    if isinstance(row, prim_type):
151        if isinstance(row, Tensor):  # mindspore.Tensor
152            item = row.asnumpy()
153        else:
154            item = np.array(row, copy=False)
155            if item.dtype == 'object':
156                raise TypeError("Data type of the input or its converted Numpy array is expected to be " \
157                                "int or float or str, but got {}.".format(item.dtype))
158        return tuple([item])
159
160    if isinstance(row, dict):
161        return tuple([row])
162
163    value = []
164    # convert each item to np.array
165    idx = 0
166    for x in row:
167        idx += 1
168        if isinstance(x, Tensor):  # mindspore.Tensor
169            value.append(x.asnumpy())
170        elif isinstance(x, dict):
171            value.append(x)
172        else:
173            item = np.array(x, copy=False)
174            if item.dtype == 'object':
175                raise TypeError("Data type of {}th item of the input or its converted Numpy array is expected to be " \
176                                "int or float or str, but got {}.".format(idx, item.dtype))
177            value.append(item)
178    return tuple(value)
179
180
181class SamplerFn:
182    """
183    Multiprocessing or multithread generator function wrapper master process.
184    """
185
186    def __init__(self, dataset, num_worker, multi_process, max_rowsize):
187        self.workers = []
188        self.dataset = dataset
189        self.num_worker = num_worker
190        self.multi_process = multi_process
191        self.max_rowsize = max_rowsize
192        self.need_join = False
193        self.ppid = os.getpid()
194        self.pids = []
195        self.check_interval = get_multiprocessing_timeout_interval()  # the interval of check queue's size
196        self._final_join = True
197
198        # Event for end of epoch
199        if multi_process is True:
200            try:
201                self.eof = multiprocessing.Event()
202            except Exception:
203                raise RuntimeError("Init multiprocessing.Event() failed, This might be caused by insufficient shm,"
204                                   + " and the recommended shm size is at least 5 GB.")
205        else:
206            self.eof = threading.Event()
207        # Create workers
208
209        # get default queue size and adjust queuesize per worker if there are large # workers
210        queue_size = get_prefetch_size()
211        queue_size = min(queue_size, queue_size * 4 // num_worker)
212        queue_size = max(2, queue_size)
213
214        if multi_process and get_enable_shared_mem():
215            # generator dataset use idx_queue and res_queue to transfer data between main and subprocess
216            # idx_queue is used multiprocess.Queue which is not shared memory, so it's size is 0.
217            # res_queue is used shared memory, so it' size is max_rowsize which is defined by user.
218            _check_shm_usage(num_worker, queue_size, 0, max_rowsize)
219        self.count = multiprocessing.Value('i', 0)
220        for worker_id in range(num_worker):
221            if multi_process is True:
222                try:
223                    worker = _GeneratorWorkerMp(dataset, self.eof, max_rowsize, queue_size, self.ppid, self.count,
224                                                worker_id)
225                    worker.daemon = True
226                    # When multi processes fork a subprocess, the lock of the main process is copied to the subprocess,
227                    # which may cause deadlock. Therefore, the subprocess startup is performed in the initialization
228                    # phase. In this phase, the main process is not locked.
229                    worker.start()
230                except OSError as e:
231                    if e.errno == errno.EMFILE:
232                        raise RuntimeError("Failed to launch multiprocessing of GeneratorDataset: "
233                                           "Too many open files. Please check if `num_parallel_workers` "
234                                           "is set too large, or you are creating iterators multiple times. "
235                                           "You can also increase the limit using `ulimit -n` in the shell "
236                                           "to avoid this error.")
237                    raise
238                except Exception as e:
239                    raise RuntimeError("Failed to launch multiprocessing of GeneratorDataset: {0}".format(e))
240                self.pids.append(worker.pid)
241                self.need_join = True
242            else:
243                worker = _GeneratorWorkerMt(dataset, self.eof, worker_id)
244                worker.daemon = True
245            self.workers.append(worker)
246        self._launch_cleanup_worker(multi_process=multi_process)
247
248    def _interval_log(self, i, start_time, wait_count):
249        cost_time = int(time.time()) - start_time
250        if cost_time / self.check_interval >= wait_count:
251            wait_count += 1
252            self._log_stuck_warning(self.workers[i % self.num_worker], cost_time)
253        return wait_count
254
255    def process(self, indices):
256        """
257        The main process, start the child process or child thread, and fill the index queue.
258        Get the result and return.
259        """
260        for w in self.workers:
261            # Check whether the queue of the subprocess is empty.
262            if not w.queue_empty():
263                # in failover reset scenario the QUIT flag should be pop first
264                while w.idx_queue.qsize() > 0:
265                    try:
266                        result = w.idx_queue.get(timeout=1)
267                        if result != "QUIT":
268                            raise Exception("The queue of the subprocess is not empty.")
269                    except queue.Empty:
270                        continue
271            # Start all workers
272            if not w.is_alive():
273                w.start()
274
275        # Fill initial index queues
276        idx_cursor = 0
277        # worker to quit
278        worker_to_quit = {}
279        idx_cursor, worker_to_quit = _fill_worker_indices(self.workers, indices, idx_cursor, worker_to_quit)
280
281        # Fetch results
282        for i in range(len(indices)):
283            if self.eof.is_set():
284                self._stop_subprocess()
285                return
286            if self.multi_process is True and not psutil.pid_exists(self.workers[i % self.num_worker].pid):
287                self._stop_subprocess()
288                return
289            # Fetch result and put index
290            try:
291                # To avoid get timeout from queue, check the res_queue size.
292                start_time = int(time.time())
293                wait_count = 1
294                while self.workers[i % self.num_worker].res_queue.empty():
295                    if self.eof.is_set():
296                        logger.warning("Generator receives a termination signal, stop waiting for data "
297                                       "from subprocess.")
298                        self._stop_subprocess()
299                        return
300                    time.sleep(0.1)
301                    wait_count = self._interval_log(i, start_time, wait_count)
302                result = self.workers[i % self.num_worker].get()
303                # Because there is no need to copy when creating Tensors in the C++layer, it reduces the time
304                # from np.ndarray to C++Tensor creation. However, when using shared memory in multiple processes,
305                # the address of the shared memory will always be passed to subsequent nodes in the dataset pipeline,
306                # and the shared memory will also be written by the current node, causing dirty data to be accessed
307                # by subsequent nodes in the pipeline. So make a memory copy here to solve the problem of
308                # shared memory being contaminated.
309                if self.multi_process is True and get_enable_shared_mem():
310                    result = copy.deepcopy(result)
311                if isinstance(result, ExceptionHandler):
312                    result.reraise()
313            except queue.Empty:
314                self._stop_subprocess()
315                raise Exception("Generator worker process timeout.")
316            except KeyboardInterrupt:
317                self._stop_subprocess()
318                raise Exception("Generator worker receives KeyboardInterrupt.")
319            if self.eof.is_set():
320                self._stop_subprocess()
321                return
322
323            idx_cursor, worker_to_quit = _fill_worker_indices(self.workers, indices, idx_cursor, worker_to_quit)
324
325            yield _convert_row(result)
326
327    def _log_stuck_warning(self, worker, waiting_time):
328        """
329        Log warning of the stuck worker, containing the worker ID, waiting time and
330        the current stack (if py-spy installed).
331
332        Args:
333            worker (Union[threading.Thread, multiprocessing.Process]): The worker instance.
334            waiting_time (int): The waiting time for getting data from the worker.
335        """
336        if self.multi_process:
337            stuck_worker_id = worker.pid
338            worker_type = "process"
339            stuck_pid = stuck_worker_id
340        else:
341            if hasattr(worker, "native_id"):
342                # only supported since Python 3.8
343                stuck_worker_id = worker.native_id
344            else:
345                stuck_worker_id = worker.ident
346            worker_type = "thread"
347            stuck_pid = os.getpid()  # get the process ID of the stuck thread
348        warning_message = "Has been waiting for data from Generator worker {0} ID '{1}' " \
349                          "for more than {2} seconds. Please check if the user defined " \
350                          "dataset of GeneratorDataset has a dead loop, or is processing " \
351                          "too slowly. ".format(worker_type, stuck_worker_id, waiting_time)
352        install_status, _ = subprocess.getstatusoutput("py-spy --version")
353        if install_status == 0:
354            stack = subprocess.getoutput("py-spy dump -p {}".format(stuck_pid))
355            warning_message += "Below is the stack of this worker:\n{0}\n".format(stack)
356        else:
357            warning_message += "You can install py-spy via `pip install py-spy`, then " \
358                               "stop and rerun your script to get the current stack. "
359        warning_message += "If it is not a problem, you can adjust the printing frequency of this log via " \
360                           "the `mindspore.dataset.config.set_multiprocessing_timeout_interval` interface."
361        logger.warning(warning_message)
362
363    def _launch_cleanup_worker(self, multi_process):
364        """
365        We need a extra thread and process if main process or subprocess was killed.
366
367        Args:
368            multi_process: Whether use multiprocess.
369        """
370        if multi_process is True and platform.system().lower() != 'windows':
371            _clean_worker_func = _PythonMultiprocessing._clean_process  # pylint: disable=W0212
372            self.cleaning_process = multiprocessing.Process(target=_clean_worker_func,
373                                                            name="GeneratorCleanProcess",
374                                                            args=(self.ppid, self.workers, self.eof))
375            self.cleaning_process.daemon = True
376            self.cleaning_process.start()
377
378            if get_enable_watchdog():
379                self.eot = threading.Event()
380                self.watch_dog = threading.Thread(target=_PythonMultiprocessing._watch_dog,  # pylint: disable=W0212
381                                                  name="GeneratorWatchDog",
382                                                  args=(self.eot, self.workers + [self.cleaning_process]))
383                self.watch_dog.daemon = True
384                self.watch_dog.start()
385
386                if self._final_join is True:
387                    self._jointhread = Finalize(
388                        self.watch_dog, self._finalize_join,
389                        args=(weakref.ref(self.watch_dog), self.eot),
390                        exitpriority=-5
391                    )
392
393    def _stop_subprocess(self):
394        """Only the main process can call join."""
395        if self.need_join is True and self.ppid == os.getpid():
396            if hasattr(self, 'eof') and self.eof is not None:
397                self.eof.set()
398            # close the watch dog first
399            self._abort_watchdog()
400            self.need_join = False
401            for w in self.workers:
402                if self.multi_process is True and hasattr(w, '_closed') and w._closed is False:  # pylint: disable=W0212
403                    try:
404                        # del the queue first
405                        del w.res_queue
406                        del w.idx_queue
407
408                        # let the quit event notify the worker process to exit
409                        w.join(timeout=5)
410                        if w.is_alive():
411                            # if the worker process did not exit, it may hang, try to terminate it
412                            w.terminate()
413                            w.close()
414                    except Exception:  # pylint: disable=W0703
415                        # Block all errors when join
416                        continue
417
418            # release the file descriptor handle
419            check_interval = get_multiprocessing_timeout_interval()
420            for w in self.workers:
421                try:
422                    subprocess_file_descriptor = w.sentinel
423                    st = time.time()
424                    while _PythonMultiprocessing.is_process_alive(w.pid):
425                        time.sleep(0.01)  # sleep 10ms, waiting for the subprocess exit
426                        if time.time() - st > check_interval:
427                            logger.warning("Waiting for the subprocess worker [{}] to exit.".format(w.pid))
428                            st += check_interval
429                except ValueError as e:
430                    if "process object is closed" in str(e):
431                        continue
432                    raise e
433                try:
434                    if w.is_alive():
435                        os.close(subprocess_file_descriptor)
436                except OSError as e:
437                    # Maybe the file descriptor had been released, so ignore the 'Bad file descriptor'
438                    if "Bad file descriptor" not in str(e):
439                        raise e
440
441            self.workers.clear()
442            self.workers = None
443
444    def _abort_watchdog(self):
445        """Let watchdog quit."""
446        if hasattr(self, 'eot') and self.eot is not None and not self.eot.is_set():
447            self.eot.set()
448        if hasattr(self, 'cleaning_process') and self.cleaning_process is not None:
449            # let the quit event notify the cleaning process to exit
450            self.cleaning_process.join(timeout=5)
451            if self.cleaning_process.is_alive():
452                # if the cleaning process did not exit, it may hang, try to terminate it
453                _PythonMultiprocessing._terminate_processes([self.cleaning_process])  # pylint: disable=W0212
454            del self.cleaning_process
455        if hasattr(self, 'count'):
456            del self.count
457
458    @classmethod
459    def _finalize_join(cls, twr, eot):
460        thread = twr()
461        if thread is not None:
462            if eot is not None and not eot.is_set():
463                eot.set()
464            thread.join()
465
466    def __del__(self):
467        try:
468            self._stop_subprocess()
469        except TypeError:
470            pass
471
472    def __deepcopy__(self, memodict, exclude=()):
473        self.__init__(self.dataset, self.num_worker, self.multi_process, self.max_rowsize)
474
475
476def _subprocess_handle(eof, signum, frame):
477    threading.Thread(target=eof.set()).start()
478
479
480def _ignore_sigint(is_multiprocessing):
481    """
482    We need to ignore sigint signal here so subprocesses can exit normally and clear.
483    """
484    if is_multiprocessing:
485        signal.signal(signal.SIGINT, signal.SIG_IGN)
486
487
488def _main_process_already_exit(eof, is_multiprocessing, idx_queue, result_queue, ppid):
489    """
490    Judge whether main process already exit.
491    """
492    if eof.is_set() or (is_multiprocessing and platform.system().lower() != 'windows' and
493                        not _PythonMultiprocessing.is_process_alive(ppid)):
494        if is_multiprocessing:
495            idx_queue.cancel_join_thread()
496            result_queue.cancel_join_thread()
497        return True
498    return False
499
500
501def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing, ppid=-1):
502    """
503    Multithread or multiprocess generator worker process loop.
504    """
505    if is_multiprocessing:
506        result_queue.cancel_join_thread()  # Ensure that the process does not hung when exiting
507        signal.signal(signal.SIGTERM, partial(_subprocess_handle, eof))
508    while not eof.is_set():
509        _ignore_sigint(is_multiprocessing=is_multiprocessing)
510
511        # Fetch index, block
512        try:
513            idx = idx_queue.get(timeout=1)
514        except queue.Empty:
515            if _main_process_already_exit(eof, is_multiprocessing, idx_queue, result_queue, ppid) is True:
516                del idx_queue
517                del result_queue
518                return
519            # If end-of-file (eof) is not set, continue to get data from idx_queue
520            continue
521        if idx == "QUIT":
522            # all the data had been processed, so we release the executor which is used by the current thread/process
523            transforms.clean_unused_executors()
524            continue
525        if idx is None:
526            # When the queue is out of scope from master process, a None item can be fetched from the queue.
527            # Upon receiving None, worker process should check if eof is set.
528            if not eof.is_set():
529                raise Exception("")
530            del idx_queue
531            del result_queue
532            return
533        if eof.is_set():
534            del idx_queue
535            del result_queue
536            return
537        # Fetch data, any exception from __getitem__ will terminate worker and timeout master process
538        try:
539            result = dataset[idx]
540        except Exception:  # pylint: disable=broad-except
541            result = ExceptionHandler(where="in GeneratorDataset worker process")
542        # Send data, block
543        while not eof.is_set():
544            try:
545                result_queue.put(result, timeout=5)
546            except queue.Full:
547                if _main_process_already_exit(eof, is_multiprocessing, idx_queue, result_queue, ppid) is True:
548                    del idx_queue
549                    del result_queue
550                    return
551                # If eof is not set, continue to put data to result_queue
552                continue
553            break
554        del result, idx
555
556
557class _GeneratorWorkerMt(threading.Thread):
558    """
559    Worker process for multi-thread Generator.
560    """
561
562    def __init__(self, dataset, eof, worker_id):
563        self.idx_queue = queue.Queue(16)
564        self.res_queue = queue.Queue(16)
565        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, False),
566                         name="GeneratorWorkerThread" + str(worker_id))
567
568    def put(self, item):
569        """
570        Put function for worker index queue. Never block. Raise queue.Full on failure.
571        """
572        self.idx_queue.put_nowait(item)
573
574    def get(self):
575        """
576        Get function for worker result queue. Block with timeout.
577        """
578        return self.res_queue.get(timeout=30)
579
580    def queue_empty(self):
581        if not self.idx_queue.empty():
582            logger.warning("idx_queue is not empty")
583            return False
584        if not self.res_queue.empty():
585            logger.warning("res_queue is not empty")
586            return False
587        return True
588
589
590class _GeneratorWorkerMp(multiprocessing.Process):
591    """
592    Worker process for multiprocess Generator.
593    """
594
595    def __init__(self, dataset, eof, max_rowsize, queue_size, ppid, count, worker_id):
596        self.idx_queue = multiprocessing.Queue(queue_size)
597        if get_enable_shared_mem():
598            self.res_queue = _SharedQueue(queue_size, count, max_rowsize=max_rowsize)
599        else:
600            self.res_queue = multiprocessing.Queue(queue_size)
601        self.idx_queue.cancel_join_thread()  # Ensure that the process does not hung when exiting
602        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True, ppid),
603                         name="GeneratorWorkerProcess" + str(worker_id))
604
605    def put(self, item):
606        """
607        Put function for worker index queue. Never block. Raise queue.Full on failure.
608        """
609        self.idx_queue.put_nowait(item)
610
611    def get(self):
612        """
613        Get function for worker result queue. Block with timeout.
614        """
615        # Relax 10s to 30s, since it sometimes will cause "Generator worker process timeout"
616        # when we run too many iterators with infinite epoch(num_epoch=-1)
617        return self.res_queue.get(timeout=30)
618
619    def queue_empty(self):
620        if not self.idx_queue.empty():
621            logger.warning("idx_queue is not empty.")
622            return False
623        if not self.res_queue.empty():
624            logger.warning("res_queue is not empty.")
625            return False
626        return True
627
628    def __del__(self):
629        # del all the Queue & SharedQueue when the iter had been deleted from ITERATORS_LIST
630        if hasattr(self, 'idx_queue'):
631            del self.idx_queue
632        if hasattr(self, 'res_queue'):
633            # del the queue when has
634            del self.res_queue
635
636
637class GeneratorDataset(MappableDataset, UnionBaseDataset):
638    """
639    A source dataset that generates data from Python by invoking Python data source each epoch.
640
641    The column names and column types of generated dataset depend on Python data defined by users.
642
643    Args:
644        source (Union[Callable, Iterable, Random Accessible]):
645            A generator callable object, an iterable Python object or a random accessible Python object.
646            Callable source is required to return a tuple of NumPy arrays as a row of the dataset on source().next().
647            Iterable source is required to return a tuple of NumPy arrays as a row of the dataset on
648            iter(source).next().
649            Random accessible source is required to return a tuple of NumPy arrays as a row of the dataset on
650            source[idx].
651        column_names (Union[str, list[str]], optional): List of column names of the dataset. Default: ``None`` .
652            Users are required to provide either column_names or schema.
653        column_types (list[mindspore.dtype], optional): List of column data types of the dataset. Default: ``None`` .
654            If provided, sanity check will be performed on generator output.
655        schema (Union[str, Schema], optional): Data format policy, which specifies the data types and shapes of the data
656            column to be read. Both JSON file path and objects constructed by :class:`mindspore.dataset.Schema` are
657            acceptable. Default: ``None`` .
658        num_samples (int, optional): The number of samples to be included in the dataset.
659            Default: ``None`` , all images.
660        num_parallel_workers (int, optional): Number of worker threads/subprocesses used to
661            fetch the dataset in parallel. Default: ``1``.
662        shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
663            Default: ``None`` , expected order behavior shown in the table below.
664        sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. Random accessible
665            input is required. Default: ``None`` , expected order behavior shown in the table below.
666        num_shards (int, optional): Number of shards that the dataset will be divided into. Default: ``None`` .
667            Random accessible input is required. When this argument is specified, `num_samples` reflects the maximum
668            sample number of per shard.
669        shard_id (int, optional): The shard ID within `num_shards` . Default: ``None`` .
670            This argument must be specified only when `num_shards` is also specified.
671            Random accessible input is required.
672        python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This
673            option could be beneficial if the Python operation is computational heavy. Default: ``True``.
674        max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory
675            allocation to copy data between processes, the total occupied shared memory will increase as
676            ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set to -1,
677            shared memory will be dynamically allocated with the actual size of data. This is only used if
678            ``python_multiprocessing`` is set to True. Default: 16.
679
680    Raises:
681        RuntimeError: If source raises an exception during execution.
682        RuntimeError: If len of column_names does not match output len of source.
683        ValueError: If `num_parallel_workers` exceeds the max thread numbers.
684        ValueError: If sampler and shuffle are specified at the same time.
685        ValueError: If sampler and sharding are specified at the same time.
686        ValueError: If `num_shards` is specified but shard_id is None.
687        ValueError: If shard_id is specified but `num_shards` is None.
688        ValueError: If `shard_id` is not in range of [0, `num_shards` ).
689
690    Tutorial Examples:
691        - `Load & Process Data With Dataset Pipeline
692          <https://www.mindspore.cn/docs/en/master/api_python/samples/dataset/dataset_gallery.html>`_
693
694    Note:
695        - If you configure `python_multiprocessing=True` (Default: ``True`` ) and `num_parallel_workers>1`
696          (default: ``1`` ) indicates that the multi-process mode is started for data load acceleration.
697          At this time, as the datasetiterates, the memory consumption of the subprocess will gradually increase,
698          mainly because the subprocess of the user-defined dataset obtains the member variables from the main
699          process in the Copy On Write way.
700          Example: If you define a dataset with `__ init__` function which contains a large number of member variable
701          data (for example, a very large file name list is loaded during the dataset construction) and uses the
702          multi-process mode, which may cause the problem of OOM (the estimated total memory usage is:
703          `(num_parallel_workers+1) * size of the parent process` ). The simplest solution is to replace Python objects
704          (such as list/dict/int/float/string) with non referenced data types
705          (such as Pandas, Numpy or PyArrow objects) for member variables, or load less meta data in member variables,
706          or configure `python_multiprocessing=False` to use multi-threading mode.
707
708          There are several classes/functions that can help you reduce the size of member variables, and you can choose
709          to use them:
710
711          1. :class:`mindspore.dataset.utils.LineReader`: Use this class to initialize your text file object in the
712          `__init__` function. Then read the file content based on the line number of the object with the `__getitem__`
713          function.
714
715        - Input `source` accepts user-defined Python functions (PyFuncs), Do not add network computing operators from
716          mindspore.nn and mindspore.ops or others into this `source` .
717        - The parameters `num_samples` , `shuffle` , `num_shards` , `shard_id` can be used to control the sampler
718          used in the dataset, and their effects when combined with parameter `sampler` are as follows.
719
720    .. include:: mindspore.dataset.sampler.txt
721
722    Examples:
723        >>> import mindspore.dataset as ds
724        >>> import numpy as np
725        >>>
726        >>> # 1) Multidimensional generator function as callable input.
727        >>> def generator_multidimensional():
728        ...     for i in range(64):
729        ...         yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
730        >>>
731        >>> dataset = ds.GeneratorDataset(source=generator_multidimensional, column_names=["multi_dimensional_data"])
732        >>>
733        >>> # 2) Multi-column generator function as callable input.
734        >>> def generator_multi_column():
735        ...     for i in range(64):
736        ...         yield np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])
737        >>>
738        >>> dataset = ds.GeneratorDataset(source=generator_multi_column, column_names=["col1", "col2"])
739        >>>
740        >>> # 3) Iterable dataset as iterable input.
741        >>> class MyIterable:
742        ...     def __init__(self):
743        ...         self._index = 0
744        ...         self._data = np.random.sample((5, 2))
745        ...         self._label = np.random.sample((5, 1))
746        ...
747        ...     def __next__(self):
748        ...         if self._index >= len(self._data):
749        ...             raise StopIteration
750        ...         else:
751        ...             item = (self._data[self._index], self._label[self._index])
752        ...             self._index += 1
753        ...             return item
754        ...
755        ...     def __iter__(self):
756        ...         self._index = 0
757        ...         return self
758        ...
759        ...     def __len__(self):
760        ...         return len(self._data)
761        >>>
762        >>> dataset = ds.GeneratorDataset(source=MyIterable(), column_names=["data", "label"])
763        >>>
764        >>> # 4) Random accessible dataset as random accessible input.
765        >>> class MyAccessible:
766        ...     def __init__(self):
767        ...         self._data = np.random.sample((5, 2))
768        ...         self._label = np.random.sample((5, 1))
769        ...
770        ...     def __getitem__(self, index):
771        ...         return self._data[index], self._label[index]
772        ...
773        ...     def __len__(self):
774        ...         return len(self._data)
775        >>>
776        >>> dataset = ds.GeneratorDataset(source=MyAccessible(), column_names=["data", "label"])
777        >>>
778        >>> # list, dict, tuple of Python is also random accessible
779        >>> dataset = ds.GeneratorDataset(source=[(np.array(0),), (np.array(1),), (np.array(2),)], column_names=["col"])
780    """
781
782    @check_generatordataset
783    def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
784                 num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
785                 python_multiprocessing=True, max_rowsize=6):
786        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
787                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
788        if isinstance(source, builtins.zip):
789            # Although zip is iterable, it does not have the feature of repeated iteration, so pass it to the array.
790            self.source = [item for item in source]
791        else:
792            self.source = source
793        self.prepared_source = None  # source to be sent to C++
794        if hasattr(self, 'operator_mixed') and getattr(self, 'operator_mixed') is True:
795            self.num_parallel_workers = 1
796            logger.warning(
797                "Input 'source' of 'GeneratorDataset' includes network computing operators like in mindspore.nn, "
798                "mindspore.ops, mindspore.numpy module and etc, which do not support multi-thread compiling, recommend"
799                " to replace it with python implemented operator like numpy etc. Here decrease 'num_parallel_workers' "
800                "into 1.")
801
802        if platform.system().lower() == 'windows' and num_parallel_workers > 1 and python_multiprocessing:
803            logger.warning("Python multiprocessing is not supported on Windows platform.")
804        self.python_multiprocessing = python_multiprocessing if platform.system().lower() != 'windows' else False
805        if self.python_multiprocessing and get_debug_mode():
806            logger.warning("Python multiprocessing is not supported in debug mode."
807                           " Ignoring Python multiprocessing for GeneratorDataset.")
808            self.python_multiprocessing = False
809
810        self.column_names = to_list(column_names)
811
812        if column_types is not None:
813            self.column_types = mstypelist_to_detypelist(column_types)
814        else:
815            self.column_types = []
816
817        self.schema = schema
818        if schema is not None:
819            self.schema = schema
820            if not isinstance(schema, Schema):
821                self.schema = Schema(schema)
822        # Move get dataset_size by len from parse to here, because self.source will
823        # lose attribution of '__len__' after deepcopy.
824        self.source_len = -1  # unknown
825        if hasattr(self.source, "__len__"):
826            self.source_len = len(self.source)
827
828            # if user defined sampler, update the self.source_len
829            if isinstance(self.sampler, samplers.Sampler) or hasattr(self.sampler, "__iter__"):
830                self.source_len = len(list(sampler))
831
832        self.max_rowsize = max_rowsize
833        self.sample_fn = None
834
835    def __deepcopy__(self, memodict):
836        if id(self) in memodict:
837            return memodict[id(self)]
838        return self.__safe_deepcopy__(memodict, exclude=("source", "__transfer_dataset__"))
839
840    def is_shuffled(self):
841        if self.sampler:
842            return self.sampler.is_shuffled()
843        return False
844
845    def is_sharded(self):
846        if self.sampler:
847            return self.sampler.is_sharded()
848        return False
849
850    def split(self, sizes, randomize=True):
851        if hasattr(self.source, "__getitem__"):
852            # If the source has __getitem__ attribute, call the split method of MappableDataset.
853            # Otherwise, call the split method of Dataset.
854            return super().split(sizes, randomize)
855        return super(MappableDataset, self).split(sizes, randomize)
856
857    def prepare_multiprocessing(self):
858        """Preprocessing of prepared_source."""
859        sample_fn = None
860        if self.sampler is not None and hasattr(self.source, "__getitem__"):
861            # The reason why there is a try catch here is because when the new op is being constructed with shared
862            # memory enabled, there will be an exception thrown if there is not enough shared memory available
863            if self.source_len == -1:
864                raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!")
865
866            if self.num_parallel_workers > 1:
867                self.__validate_memory_usage()
868
869                sample_fn = SamplerFn(self.source, self.num_parallel_workers, self.python_multiprocessing,
870                                      self.max_rowsize)
871                self.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
872            else:
873                self.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
874            self.sample_fn = sample_fn
875        else:
876            self.sampler = None
877            self.sample_fn = sample_fn
878            self.source_len = min(self.source_len, self.num_samples) if self.num_samples != 0 else self.source_len
879            if not hasattr(self.source, "__iter__"):
880                # Use generator function if input callable
881                self.prepared_source = (lambda: _generator_fn(self.source, self.num_samples))
882            else:
883                # Use iterator function if input is iterable
884                # Random accessible input is also iterable
885                self.prepared_source = (lambda: _iter_fn(self.source, self.num_samples))
886
887    def parse(self, children=None):
888        self.prepare_multiprocessing()
889        if self.schema is None:
890            return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
891                                     self.sampler, self.num_parallel_workers)
892        schema = self.schema
893        if isinstance(schema, Schema):
894            schema = self.schema.cpp_schema
895        return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler,
896                                 self.num_parallel_workers)
897
898    def __validate_memory_usage(self):
899        """
900        Check memory usage when mulit-processing mode, when 85% prompt warning and 100% raise error.
901        """
902        if self.python_multiprocessing:
903            # if use num_parallel_workers is to large when python_multiprocessing=True which would cause
904            # OOM error get the num_shards
905            valid_num_shards = 1
906            if isinstance(self.sampler, samplers.DistributedSampler):
907                valid_num_shards = self.sampler.num_shards
908            elif self.num_shards is not None:
909                valid_num_shards = self.num_shards
910
911            # get process memory usage
912            process = psutil.Process(os.getpid())
913            process_memory = process.memory_info().rss
914            sys_memory_available = psutil.virtual_memory().available
915
916            total_memory_maybe_used = process_memory * self.num_parallel_workers * valid_num_shards
917            if total_memory_maybe_used / sys_memory_available > 0.85:
918                valid_num_worker = math.floor(sys_memory_available * 0.85 / valid_num_shards / process_memory)
919                valid_num_worker = 1 if valid_num_worker <= 0 else valid_num_worker
920                info = "GeneratorDataset's num_parallel_workers: {} is too large which may cause a lot of memory " \
921                       "occupation (>85%) or out of memory(OOM) during multiprocessing. Therefore, it is recommended " \
922                       "to reduce num_parallel_workers to {} or smaller.".format(self.num_parallel_workers,
923                                                                                 valid_num_worker)
924                logger.warning(info)
925
926
927class _NumpySlicesDataset:
928    """
929    Mainly for dealing with several kinds of formats of Python data, and return one row each time.
930    """
931
932    def __init__(self, data, column_list=None):
933        self.column_list = None
934        # Convert dict data into tuple
935        if isinstance(data, dict):
936            data = self.process_dict(data)
937
938        if isinstance(data, tuple):
939            self.data = data
940        else:
941            self.data = (data,)
942
943        # check whether the data length in each column is equal
944        data_len = [len(data_item) for data_item in self.data]
945        if data_len[1:] != data_len[:-1]:
946            raise ValueError("Data length in each column is not equal.")
947
948        # Init column_name
949        if column_list is not None:
950            self.column_list = column_list
951        elif self.column_list is None:
952            self.column_list = []
953            column_num = len(self.data)
954            for i in range(column_num):
955                self.column_list.append("column_" + str(i))
956
957    def __getitem__(self, index):
958        data_row = [d[index] for d in self.data]
959        data_res = tuple(data_row)
960        return data_res
961
962    def __len__(self):
963        return len(self.data[0])
964
965    def process_dict(self, input_data):
966        """
967        Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first.
968        """
969        # Convert pandas like dict(has "values" column) into General dict
970        data_keys = list(input_data.keys())
971        data_col = input_data[data_keys[0]]
972        if hasattr(data_col, "values"):
973            new_dict = {}
974            for key in data_keys:
975                item1 = input_data.pop(key)
976                new_dict[key] = item1.values
977            input_data = new_dict
978
979        # Convert the data in dict into tuple
980        data = ()
981        keys = list(input_data.keys())
982        self.column_list = keys
983        for key in keys:
984            value = input_data[key]
985            data = data + (list(value),)
986
987        return data
988
989
990class NumpySlicesDataset(GeneratorDataset):
991    """
992    Creates a dataset with given data slices, mainly for loading Python data into dataset.
993
994    The column names and column types of generated dataset depend on Python data defined by users.
995
996    Args:
997        data (Union[list, tuple, dict]) Input of given data. Supported data types include: list, tuple, dict and other
998            NumPy formats. Input data will be sliced along the first dimension and generate additional rows, if input is
999            list, there will be one column in each row, otherwise there tends to be multi columns. Large data is not
1000            recommended to be loaded in this way as data is loading into memory.
1001        column_names (list[str], optional): List of column names of the dataset. Default: ``None`` . If `column_names`
1002            is not provided, the output column names will be named as the keys of dict when the input data is a dict,
1003            otherwise they will be named like column_0, column_1 ...
1004        num_samples (int, optional): The number of samples to be included in the dataset. Default: ``None`` ,
1005            all samples.
1006        num_parallel_workers (int, optional): Number of worker subprocesses used to
1007            fetch the dataset in parallel. Default: ``1``.
1008        shuffle (bool, optional): Whether or not to perform shuffle on the dataset.
1009            Default: ``None`` , expected order behavior shown in the table below.
1010        sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset.
1011            Default: ``None`` , expected order behavior shown in the table below.
1012        num_shards (int, optional): Number of shards that the dataset will be divided into. Default: ``None`` .
1013            When this argument is specified, `num_samples` reflects the max sample number of per shard.
1014        shard_id (int, optional): The shard ID within `num_shards` . Default: ``None`` . This argument must be
1015            specified only when `num_shards` is also specified.
1016
1017    Note:
1018        - The parameters `num_samples` , `shuffle` , `num_shards` , `shard_id` can be used to control the sampler
1019          used in the dataset, and their effects when combined with parameter `sampler` are as follows.
1020
1021    .. include:: mindspore.dataset.sampler.txt
1022
1023    Raises:
1024        RuntimeError: If len of column_names does not match output len of data.
1025        ValueError: If `num_parallel_workers` exceeds the max thread numbers.
1026        ValueError: If sampler and shuffle are specified at the same time.
1027        ValueError: If sampler and sharding are specified at the same time.
1028        ValueError: If `num_shards` is specified but shard_id is None.
1029        ValueError: If shard_id is specified but `num_shards` is None.
1030        ValueError: If `shard_id` is not in range of [0, `num_shards` ).
1031
1032    Tutorial Examples:
1033        - `Load & Process Data With Dataset Pipeline
1034          <https://www.mindspore.cn/docs/en/master/api_python/samples/dataset/dataset_gallery.html>`_
1035
1036    Examples:
1037        >>> import mindspore.dataset as ds
1038        >>> # 1) Input data can be a list
1039        >>> data = [1, 2, 3]
1040        >>> dataset = ds.NumpySlicesDataset(data=data, column_names=["column_1"])
1041        >>>
1042        >>> # 2) Input data can be a dictionary, and column_names will be its keys
1043        >>> data = {"a": [1, 2], "b": [3, 4]}
1044        >>> dataset = ds.NumpySlicesDataset(data=data)
1045        >>>
1046        >>> # 3) Input data can be a tuple of lists (or NumPy arrays), each tuple element refers to data in each column
1047        >>> data = ([1, 2], [3, 4], [5, 6])
1048        >>> dataset = ds.NumpySlicesDataset(data=data, column_names=["column_1", "column_2", "column_3"])
1049        >>>
1050        >>> # 4) Load data from CSV file
1051        >>> import pandas as pd
1052        >>> df = pd.read_csv(filepath_or_buffer=csv_dataset_dir[0])
1053        >>> dataset = ds.NumpySlicesDataset(data=dict(df), shuffle=False)
1054    """
1055
1056    @check_numpyslicesdataset
1057    def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None,
1058                 num_shards=None, shard_id=None):
1059        dataset = _NumpySlicesDataset(data, column_names)
1060        super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples,
1061                         num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
1062                         num_shards=num_shards, shard_id=shard_id)
1063
1064
1065class _PaddedDataset:
1066    """
1067    Mainly for combining false samples provided by users into a dataset.
1068
1069    Args:
1070        padded_samples (list(dict)): Data provided by user to be added to the initial Dataset.
1071    """
1072
1073    def __init__(self, padded_samples):
1074        self.column_names = list(padded_samples[0].keys())
1075        self.padded_samples = padded_samples
1076
1077    def __getitem__(self, item):
1078        return (self.padded_samples[item][key] for key in self.column_names)
1079
1080    def __len__(self):
1081        return len(self.padded_samples)
1082
1083
1084class PaddedDataset(GeneratorDataset):
1085    """
1086    Creates a dataset with filler data provided by user.
1087
1088    Mainly used to add to the original dataset and assign it to the corresponding shard.
1089
1090    Args:
1091        padded_samples (list(dict)): Samples provided by user.
1092
1093    Raises:
1094        TypeError: If padded_samples is not an instance of list.
1095        TypeError: If the element of padded_samples is not an instance of dict.
1096        ValueError: If the padded_samples is empty.
1097
1098    Tutorial Examples:
1099        - `Load & Process Data With Dataset Pipeline
1100          <https://www.mindspore.cn/docs/en/master/api_python/samples/dataset/dataset_gallery.html>`_
1101
1102    Examples:
1103        >>> import mindspore.dataset as ds
1104        >>> import numpy as np
1105        >>> data = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)}]
1106        >>> dataset = ds.PaddedDataset(padded_samples=data)
1107    """
1108
1109    @check_paddeddataset
1110    def __init__(self, padded_samples):
1111        dataset = _PaddedDataset(padded_samples)
1112        super().__init__(dataset, column_names=dataset.column_names, num_shards=None, shard_id=None, shuffle=False)
1113        self._dataset_size = len(dataset.padded_samples)
1114        self.padded_samples = padded_samples
1115