1# Copyright 2022-2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
161. This file provides the abstraction of the dataset loading classes. It contains
17some basic dataset operations (skip, filter, map, batch, ...).
182. Specific dataset loading classes can be found in the datasets_vision.py, datasets_text.py,
19datasets_audio.py, datasets_standard_format.py and datasets_user_defined.py files.
20    datasets_vision.py: contains vision dataset loading classes.
21    datasets_text.py: contains text dataset loading classes.
22    datasets_audio.py: contains audio dataset loading classes.
23    datasets_standard_format.py: contains standard format loading classes which
24                                 any other kinds of datasets can be converted to.
25    datasets_user_defined.py: contains basic classes that help users define
26                              flexible ways to load datasets.
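
A minimal, illustrative sketch of chaining these basic operations (the data and
column name below are placeholders, not part of this module):

    >>> import mindspore.dataset as ds
    >>> dataset = ds.NumpySlicesDataset([0, 1, 2, 3, 4, 5, 6, 7], column_names=["data"])
    >>> dataset = dataset.skip(2)                    # drop the first 2 rows
    >>> dataset = dataset.filter(predicate=lambda d: d % 2 == 0, input_columns=["data"])
    >>> dataset = dataset.map(operations=[lambda d: d * 10], input_columns=["data"])
    >>> dataset = dataset.batch(batch_size=2)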
27"""
28import atexit
29import glob
30import json
31import os
32import queue
33import signal
34import stat
35import subprocess
36import warnings
37
38import gc
39import time
40import uuid
41import multiprocessing
42from enum import Enum
43from importlib import import_module
44import sys
45import threading
46
47import copy
48import weakref
49import platform
50import psutil
51
52import mindspore._c_dataengine as cde
53from mindspore._c_expression import typing
54
55from mindspore import log as logger
56from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _get_ps_context,\
57                                           _enable_distributed_mindrt
58from mindspore.dataset.engine.offload import GetOffloadModel
59
60import mindspore.dataset.transforms.c_transforms as c_transforms
61import mindspore.dataset.transforms.py_transforms as py_transforms
62import mindspore.dataset.transforms as transforms
63from mindspore.dataset.text.utils import SentencePieceModel, DE_C_INTER_SENTENCEPIECE_MODE
64from mindspore.parallel._utils import _get_device_num
65from mindspore.dataset.debug import DebugHook
66
67from mindspore.dataset.engine import samplers
68from .iterators import DictIterator, TupleIterator, DummyIterator, check_iterator_cleanup, _set_iterator_cleanup, \
69    ITERATORS_LIST, _unset_iterator_cleanup
70from .queue import _SharedQueue, _Queue
71from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
72    check_rename, check_device_send, check_take, check_output_shape, check_project, \
73    check_sync_wait, check_zip_dataset, check_add_column, check_concat, check_split, check_bucket_batch_by_length, \
74    check_save, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_padded_batch, \
75    check_total_batch
76from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
77    get_enable_watchdog, get_seed, set_seed, get_debug_mode, get_multiprocessing_timeout_interval, _get_debug_hook_list
78from ..core.datatypes import mstype_to_detype
79from ..core.validator_helpers import replace_none
80from ..core.py_util_helpers import ExceptionHandler
81from ..transforms.py_transforms_util import FuncWrapper, Implementation
82from ..vision.transforms import ToNumpy
83from ...mindrecord.config import _get_enc_key, _get_enc_mode, _get_hash_mode, encrypt, append_hash_to_file
84
85try:
86    context = import_module("mindspore.context")
87except ModuleNotFoundError:
88    context = None
89
90if platform.system().lower() == "darwin" and multiprocessing.get_start_method() != "fork":
91    multiprocessing.set_start_method("fork", True)
92
93OffloadToManualOffloadMode = {
94    None: cde.ManualOffloadMode.UNSPECIFIED,
95    False: cde.ManualOffloadMode.DISABLED,
96    True: cde.ManualOffloadMode.ENABLED
97}
98
99_train_dataset = None
100
101
102def _set_training_dataset(dataset):
103    """
104    Set the dataset to be used when training recovery has occurred.
105
106    Args:
107        dataset: the training dataset or iterator
108    """
109    global _train_dataset
110    _train_dataset = dataset
111
112
113def _get_training_dataset():
114    """
115    Get the dataset to be used when training recovery has occurred.
116
117    Returns:
118        training dataset/iterator
119    """
120    return _train_dataset
121
122
123def _reset_training_dataset(global_step, dataset_size):
124    """
125    Reset the training dataset to the given global step.
126
127    Args:
128        global_step (int): Number of global steps that have completed training.
129            Dataset will provide data from its next step after reset.
130        dataset_size (int): Number of steps per epoch.
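
    Examples:
        >>> # An illustrative recovery sketch. It assumes `train_dataset` is a dataset or
        >>> # iterator that was registered earlier via _set_training_dataset(train_dataset).
        >>> # Resume from global step 100 with 50 steps per epoch.
        >>> _reset_training_dataset(global_step=100, dataset_size=50)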
131    """
132    dataset = _get_training_dataset()
133    if dataset is not None:
134        dataset._reset(global_step, dataset_size)  # pylint: disable=protected-access
135    else:
136        raise RuntimeError("Training dataset is not set.")
137
138
139class Shuffle(str, Enum):
140    """Specify the shuffle mode.
141
142    - ``Shuffle.GLOBAL`` : Shuffle both the files and samples.
143    - ``Shuffle.FILES`` : Shuffle files only.
144    - ``Shuffle.INFILE`` : Shuffle data within each file.
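
    Examples:
        >>> # Illustrative sketch: pass a shuffle mode to a file-based dataset
        >>> # (the TFRecord file path below is a placeholder).
        >>> import mindspore.dataset as ds
        >>> dataset = ds.TFRecordDataset("/path/to/file.tfrecord", shuffle=ds.Shuffle.FILES)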
145    """
146    GLOBAL: str = "global"
147    FILES: str = "files"
148    INFILE: str = "infile"
149
150
151ShuffleToShuffleMode = {Shuffle.FILES: cde.ShuffleMode.FILES,
152                        Shuffle.GLOBAL: cde.ShuffleMode.GLOBAL,
153                        Shuffle.INFILE: cde.ShuffleMode.INFILE}
154
155
156def shuffle_to_shuffle_mode(shuffle):
157    """
158    Convert a Shuffle enum (or bool/None flag) to the C layer ShuffleMode.
159
160    Args:
161        shuffle (Union[Shuffle, bool, None]): shuffle flag to be converted to a shuffle mode in the C layer
162
163    Returns:
164        ShuffleMode, shuffle mode
165    """
166    shuffle_mode = cde.ShuffleMode.GLOBAL  # Global shuffle
167    if not isinstance(shuffle, Shuffle):
168        if shuffle is None or shuffle:
169            shuffle_mode = cde.ShuffleMode.GLOBAL  # Global shuffle
170        else:
171            shuffle_mode = cde.ShuffleMode.FALSE  # No shuffle
172    else:
173        shuffle_mode = ShuffleToShuffleMode[shuffle]
174    return shuffle_mode
175
176
177def shuffle_to_bool(shuffle):
178    """
179    Convert a Shuffle enum (or bool/None flag) to bool.
180
181    Args:
182        shuffle (Union[Shuffle, bool, None]): shuffle flag to be converted to bool
183
184    Returns:
185        Union[bool, None], the converted value. None is returned when `shuffle` is None.
186    """
187    if shuffle is not None and not isinstance(shuffle, (bool, Shuffle)):
188        raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like 'Shuffle.GLOBAL' or "
189                        "'Shuffle.FILES' or 'Shuffle.INFILE'.")
190
191    shuffle_bool = True
192    if not isinstance(shuffle, Shuffle):
193        if shuffle is None:
194            shuffle_bool = None
195        elif shuffle:
196            shuffle_bool = True
197        else:
198            shuffle_bool = False
199    else:
200        shuffle_bool = True
201    return shuffle_bool
202
203
204@check_zip
205def zip(datasets):
206    """
207    Zip the datasets in the input tuple of datasets.
208
209    Args:
210        datasets (tuple[Dataset]): A tuple of datasets to be zipped together.
211            The number of datasets must be more than 1.
212
213    Returns:
214        Dataset, a new dataset with the above operation applied.
215
216    Raises:
217        ValueError: If the number of datasets is 1.
218        TypeError: If datasets is not a tuple.
219
220    Examples:
221            >>> # Create a dataset which is the combination of dataset_1 and dataset_2
222            >>> import mindspore.dataset as ds
223            >>>
224            >>> dataset_1 = ds.GeneratorDataset([1], "column1")
225            >>> dataset_2 = ds.GeneratorDataset([2], "column2")
226            >>> dataset = ds.zip((dataset_1, dataset_2))
227    """
228    if len(datasets) <= 1:
229        raise ValueError(
230            "Can't zip empty or just one dataset!")
231    for dataset in datasets:
232        if not isinstance(dataset, Dataset):
233            raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset))
234    return ZipDataset(datasets)
235
236
237def _get_operator_process():
238    """
239    Internal method, mainly for passing sub-process ids to the C layer.
240
241    Returns:
242         tuple(dict, bool), a dict mapping operation ids to worker process ids, and whether all ids were fetched.
243    """
244    global _OP_PROCESS
245    process_info = _OP_PROCESS
246    op_process = dict()
247    keys = process_info.keys()
248    fetched_all = True
249    for key in keys:
250        try:
251            op_process[key] = list(process_info[key][1])
252            item_full = (len(process_info[key][1]) == process_info[key][0])
253        except KeyError as err:
254            raise err
255        fetched_all = fetched_all and item_full
256    return op_process, fetched_all
257
258
259def _set_dataset_permissions(file_name, num_files):
260    """
261    Set the saved dataset files' permissions to 600.
262    The naming rule of the dataset files must be the same as the one used in the C++ layer.
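
    Examples:
        >>> # Illustrative sketch: after a save produced 4 files named
        >>> # /path/to/out.mindrecord0 ... /path/to/out.mindrecord3 (each with a .db index file),
        >>> # restrict them to owner read/write only.
        >>> _set_dataset_permissions("/path/to/out.mindrecord", 4)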
263    """
264    num_digits = len(str(num_files - 1))
265    if num_files == 1:
266        paths = [file_name]
267    else:
268        paths = ["{}{}".format(file_name, str(x).rjust(num_digits, '0')) for x in range(num_files)]
269
270    for item in paths:
271        if os.path.exists(item):
272            os.chmod(item, stat.S_IRUSR | stat.S_IWUSR)
273            index_file = item + ".db"
274            if os.path.exists(index_file):
275                os.chmod(index_file, stat.S_IRUSR | stat.S_IWUSR)
276
277
278class Dataset:
279    """
280    Abstract class to represent a dataset in DataEngine's data pipeline.
281
282    This class is the base class of all dataset classes (see the hierarchy below), and represents
283    a node in the data flow graph.
284                                     Dataset
285           -----------------------------------------------------------
286           |                  |                   |                  |
287    VisionBaseDataset    TextBaseDataset    AudioBaseDataset         |
288           -                  -                   -                  |
289           |                  |                   |                  |
290           ----------------------------------------                  |
291                      UnionBaseDataset                               |
292                                                                     |
293                                                               SourceDataset
294                                                                     -
295                                                                     |
296                                                              MappableDataset
297
298    DatasetOperation: MapDataset(UnionBaseDataset)
299                      BatchDataset(UnionBaseDataset)
300                      PaddedBatchDataset(UnionBaseDataset)
301                      BucketBatchByLengthDataset(UnionBaseDataset)
302                      ShuffleDataset(UnionBaseDataset)
303                      FilterDataset(UnionBaseDataset)
304                      RepeatDataset(UnionBaseDataset)
305                      SkipDataset(UnionBaseDataset)
306                      TakeDataset(UnionBaseDataset)
307                      ZipDataset(UnionBaseDataset)
308                      ConcatDataset(UnionBaseDataset)
309                      RenameDataset(UnionBaseDataset)
310                      ProjectDataset(UnionBaseDataset)
311                      SyncWaitDataset(UnionBaseDataset)
312
313    Impl Dataset - vision:       ImageFolderDataset(MappableDataset, VisionBaseDataset)
314                                 USPSDataset(SourceDataset, VisionBaseDataset)
315    Impl Dataset - text:         TextFileDataset(SourceDataset, TextBaseDataset)
316                                 YahooAnswersDataset(SourceDataset, TextBaseDataset)
317    Impl Dataset - audio:        LJSpeechDataset(MappableDataset, AudioBaseDataset)
318                                 TedliumDataset(MappableDataset, AudioBaseDataset)
319    Impl Dataset - standard:     MindDataset(MappableDataset, UnionBaseDataset)
320                                 TFRecordDataset(SourceDataset, UnionBaseDataset)
321    Impl Dataset - user defined: GeneratorDataset(MappableDataset, UnionBaseDataset)
322                                 NumpySlicesDataset(GeneratorDataset)
323
324    Args:
325        num_parallel_workers (int, optional): Number of workers to process the dataset in parallel.
326            Default: ``None``.
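
    Examples:
        >>> # Dataset itself is abstract; in practice a concrete subclass is constructed and then
        >>> # transformed through the operations defined here (the directory below is a placeholder).
        >>> import mindspore.dataset as ds
        >>> image_folder_dataset_dir = "/path/to/image_folder_dataset_directory"
        >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir)
        >>> dataset = dataset.shuffle(buffer_size=4)
        >>> dataset = dataset.batch(batch_size=2, drop_remainder=True)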
327    """
328
329    def __init__(self, children=None, num_parallel_workers=None, cache=None):
330        # Note: children and parent are internal variables, not recommended for external use.
331        self.children = replace_none(children, [])
332        if isinstance(self.children, tuple):
333            self.children = list(self.children)
334        if not isinstance(self.children, list):
335            self.children = [self.children]
336
337        self.parent = []
338        for child in self.children:
339            child.parent.append(weakref.ref(self))
340        self.num_parallel_workers = num_parallel_workers
341        self.cache = cache
342
343        self._device_iter = 0
344        self._input_indexs = ()
345        self.saved_output_types = None
346        self.saved_output_shapes = None
347        self.estimated_output_shapes = None
348        self.runtime_context = None
349        self._col_names = None
350        self.dataset_size = None
351        self._batch_size = None
352        self._num_classes = None
353        self._repeat_count = None
354        self._class_indexing = None
355        self._sync = False
356        self._global_step = None
357
358    @staticmethod
359    def _get_operator_id(dataset):
360        """
361        Internal method to iterate the tree and obtain op_id of each operation.
362
363        Returns:
364            dict, a mapping from each operation in the tree to its op_id.
365        """
366        op_name = dict()
367        generator_process = dict()
368        op_name[str(dataset)] = 0
369        op_id = 1
370
371        def process_name(datasets, operator_id):
372            if not datasets:
373                return 0
374            temp = []
375            for item in datasets:
376                for d in item.children:
377                    temp.append(d)
378                    op_name[str(d)] = operator_id
379
380                    from mindspore.dataset.engine.datasets_user_defined import GeneratorDataset
381                    if isinstance(d, GeneratorDataset) and d.sample_fn and d.sample_fn.pids:
382                        generator_process[operator_id] = [d.num_parallel_workers, set(d.sample_fn.pids)]
383
384            operator_id = operator_id + 1
385            return process_name(temp, operator_id)
386
387        process_name([dataset], op_id)
388        if generator_process:
389            global _OP_PROCESS
390            _OP_PROCESS.update(generator_process)
391        return op_name
392
393    def create_ir_tree(self, getter_mode=False):
394        """
395        Internal method to build an IR tree.
396
397        Args:
398            getter_mode (bool, optional): Whether to build IR tree in pull mode. Default: ``False``.
399
400        Returns:
401            tuple[DatasetNode, Dataset], the root node of the IR tree and the copied root dataset of the tree.
402        """
403        parent = self.parent
404        self.parent = []
405        dataset = copy.deepcopy(self)
406        global _OP_NAME
407        _OP_NAME = Dataset._get_operator_id(dataset)
408        ir_tree = dataset.parse_tree(getter_mode)
409        self.parent = parent
410        _init_device_info()
411        return ir_tree, dataset
412
413    def parse_tree(self, getter_mode=False):
414        """
415        Internal method to parse the API tree into an IR tree.
416
417        Args:
418            getter_mode (bool, optional): Whether to build IR tree in pull mode. Default: ``False``.
419
420        Returns:
421            DatasetNode, the root node of the IR tree.
422        """
423        if len(self.parent) > 1:
424            raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
425        ir_children = [d.parse_tree(getter_mode) for d in self.children]
426        # Bootstrap can only be performed on a copy of the original dataset node.
427        # Bootstrap on original dataset node will make all iterators share the same process pool
428        self.pre_parse(getter_mode)
429        self.iterator_bootstrap()
430        ir_node = self.parse(ir_children)
431        ir_node = self.post_parse(ir_node)
432        return ir_node
433
434    def __safe_deepcopy__(self, memodict, exclude=()):
435        if id(self) in memodict:
436            return memodict[id(self)]
437        cls = self.__class__
438        new_op = cls.__new__(cls)
439        memodict[id(self)] = new_op
440        for arg, value in self.__dict__.items():
441            if arg in exclude:
442                setattr(new_op, arg, value)
443            else:
444                try:
445                    setattr(new_op, arg, copy.deepcopy(value, memodict))
446                except TypeError:
447                    setattr(new_op, arg, value)
448        return new_op
449
450    @staticmethod
451    def _noop_mode():
452        if _is_role_sched():
453            return True
454        return False
455
456    def iterator_bootstrap(self):
457        pass
458
459    def __add__(self, datasets):
460        return self.concat(datasets)
461
462    def to_json(self, filename=""):
463        """
464        Serialize the pipeline into a JSON string and dump it into a file if `filename` is provided.
465
466        Args:
467            filename (str): filename of JSON file to be saved as. Default: ``""``.
468
469        Returns:
470            dict, the JSON representation of the pipeline (the serialized string parsed back into a dict).
471
472        Examples:
473            >>> import mindspore.dataset as ds
474            >>> mnist_dataset_dir = "/path/to/mnist_dataset_directory"
475            >>> dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir)
476            >>> dataset_json = dataset.to_json("/path/to/mnist_dataset_pipeline.json")
477        """
478        ir_tree, _ = self.create_ir_tree()
479        return json.loads(ir_tree.to_json(filename))
480
481    @check_bucket_batch_by_length
482    def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes, element_length_function=None,
483                               pad_info=None, pad_to_bucket_boundary=False, drop_remainder=False):
484        """
485        Bucket elements according to their lengths. Each bucket will be padded and batched when
486        it is full.
487
488        A length function is called on each row in the dataset. The row is then
489        bucketed based on its length and bucket boundaries. When a bucket reaches its
490        corresponding size specified in bucket_batch_sizes, the entire bucket will be
491        padded according to pad_info, and then form a batch.
492
493        Refer to the following figure for the execution process:
494
495        .. image:: bucket_batch_by_length_en.png
496
497        Args:
498            column_names (list[str]): Columns passed to element_length_function.
499            bucket_boundaries (list[int]): A list consisting of the upper boundaries
500                of the buckets. Must be strictly increasing. If there are n boundaries,
501                n+1 buckets are created: One bucket for [0, bucket_boundaries[0]), one
502                bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each
503                0<i<n-1, and the last bucket for [bucket_boundaries[n-1], inf).
504            bucket_batch_sizes (list[int]): A list consisting of the batch sizes for
505                each bucket. Must contain len(bucket_boundaries)+1 elements.
506            element_length_function (Callable, optional): A function that takes in
507                M arguments where M = len(column_names) and returns an integer. If no value is
508                provided, len(column_names) must be 1, and the size of the first
509                dimension of that column will be taken as the length. Default: ``None``.
510            pad_info (dict, optional): The information about how to batch each column. The key
511                corresponds to the column name, and the value must be a tuple of 2 elements.
512                The first element corresponds to the shape to pad to, and the second
513                element corresponds to the value to pad with. If a column is not
514                specified, then that column will be padded to the longest in the current
515                batch, and 0 will be used as the padding value. Any None dimensions will
516                be padded to the longest in the current batch, unless if
517                `pad_to_bucket_boundary` is ``True``. If no padding is wanted, set `pad_info`
518                to ``None``. Default: ``None``.
519            pad_to_bucket_boundary (bool, optional): If ``True``, will pad each None
520                dimension in `pad_info` to the bucket_boundary minus 1. If there are any
521                elements that fall into the last bucket, an error will occur.
522                Default: ``False``.
523            drop_remainder (bool, optional): If ``True``, will drop the last batch for each
524                bucket if it is not a full batch. Default: ``False``.
525
526        Returns:
527            Dataset, a new dataset with the above operation applied.
528
529        Examples:
530            >>> # Create a dataset where certain counts rows are combined into a batch
531            >>> # and drops the last incomplete batch if there is one.
532            >>> import mindspore.dataset as ds
533            >>> import numpy as np
534            >>> def generate_2_columns(n):
535            ...     for i in range(n):
536            ...         yield (np.array([i]), np.array([j for j in range(i + 1)]))
537            >>>
538            >>> column_names = ["col1", "col2"]
539            >>> dataset = ds.GeneratorDataset(generate_2_columns(8), column_names)
540            >>> bucket_boundaries = [5, 10]
541            >>> bucket_batch_sizes = [2, 1, 1]
542            >>> element_length_function = (lambda col1, col2: max(len(col1), len(col2)))
543            >>> # Will pad col2 to shape [bucket_boundaries[i]] where i is the
544            >>> # index of the bucket that is currently being batched.
545            >>> pad_info = {"col2": ([None], -1)}
546            >>> pad_to_bucket_boundary = True
547            >>> dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
548            ...                                          bucket_batch_sizes,
549            ...                                          element_length_function, pad_info,
550            ...                                          pad_to_bucket_boundary)
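            >>>
            >>> # With bucket_boundaries=[5, 10], three buckets are formed: rows with length
            >>> # in [0, 5) go to bucket 0 (batch size 2), rows with length in [5, 10) go to
            >>> # bucket 1 (batch size 1), and rows with length >= 10 go to bucket 2 (batch size 1).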
551        """
552        return BucketBatchByLengthDataset(self, column_names, bucket_boundaries, bucket_batch_sizes,
553                                          element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder)
554
555    @check_batch
556    def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, **kwargs):
557        """
558        Combine batch_size consecutive rows into one batch, applying per_batch_map to the samples first.
559
560        For any column, all the elements within that column must have the same shape.
561
562        Refer to the following figure for the execution process:
563
564        .. image:: batch_en.png
565
566        Note:
567            The order of repeat and batch operations affects the number of batches and the behavior of per_batch_map.
568            It is recommended to apply the repeat operation after the batch operation.
569
570        Args:
571            batch_size (Union[int, Callable]): The number of rows each batch is created with. An
572                int or callable object which takes exactly 1 parameter, BatchInfo.
573            drop_remainder (bool, optional): Determines whether or not to drop the last block
574                whose data row number is less than batch size. Default: ``False`` . If ``True`` ,
575                and if there are less than `batch_size` rows available to make the last batch,
576                then those rows will be dropped and not propagated to the child node.
577            num_parallel_workers (int, optional): Number of workers(threads) to process the dataset in parallel.
578                Default: ``None`` .
579            **kwargs:
580
581                - per_batch_map (Callable[[List[numpy.ndarray], ..., List[numpy.ndarray], BatchInfo], \
582                  (List[numpy.ndarray], ..., List[numpy.ndarray])], optional): Per batch map callable.
583                  Default: ``None``.
584                  A callable which takes (List[numpy.ndarray], ..., List[numpy.ndarray], BatchInfo) as input parameters.
585                  Each list[numpy.ndarray] represents a batch of numpy.ndarray on a given column. The number of lists
586                  should match with the number of entries in input_columns. The last parameter of the callable should
587                  always be a BatchInfo object. Per_batch_map should return
588                  (list[numpy.ndarray], list[numpy.ndarray], ...). The length of each list in output should be the same
589                  as the input. output_columns is required if the number of output lists is different from input.
590
591                - input_columns (Union[str, list[str]], optional): List of names of the input columns. The size of
592                  the list should match with signature of `per_batch_map` callable. Default: ``None`` .
593
594                - output_columns (Union[str, list[str]], optional): List of names assigned to the columns
595                  outputted by the last operation. This parameter is mandatory if len(input_columns) !=
596                  len(output_columns). The size of this list must match the number of output
597                  columns of the last operation. Default: ``None`` , output columns will have the same
598                  name as the input columns, i.e., the columns will be replaced.
599
600                - python_multiprocessing (bool, optional): Parallelize Python function `per_batch_map` with
601                  multi-processing or multi-threading mode, ``True`` means multi-processing,
602                  ``False`` means multi-threading. If `per_batch_map` is an I/O-bound task, use
603                  multi-threading mode. If `per_batch_map` is a CPU-bound task, it is recommended to use
604                  multi-processing mode. Default: ``False`` , use python multi-threading mode.
605
606                - max_rowsize(Union[int, list[int]], optional): Maximum size of row in MB that is used for shared memory
607                  allocation to copy data between processes, the total occupied shared memory will increase as
608                  ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set
609                  to -1, shared memory will be dynamically allocated with the actual size of data. This is only used if
610                  ``python_multiprocessing`` is set to True. If it is an int value, it represents
611                  ``input_columns`` and ``output_columns`` use this value as the unit to create shared memory.
612                  If it is a list, the first element represents the ``input_columns`` use this value as the unit to
613                  create shared memory, and the second element represents ``output_columns`` use this value as the unit
614                  to create shared memory. Default: 16.
615
616        Returns:
617            Dataset, a new dataset with the above operation applied.
618
619        Examples:
620            >>> # 1) Create a dataset where every 5 rows are combined into a batch
621            >>> # and drops the last incomplete batch if there is one.
622            >>> import mindspore.dataset as ds
623            >>> from PIL import Image
            >>> import numpy as np
624            >>>
625            >>> cifar10_dataset_dir = "/path/to/cifar10_dataset_directory"
626            >>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_samples=10)
627            >>> dataset = dataset.batch(5, True)
628            >>>
629            >>> # 2) resize image according to its batch number, if it's 5-th batch, resize to (5^2, 5^2) = (25, 25)
630            >>> def np_resize(col, BatchInfo):
631            ...     output = col.copy()
632            ...     s = (BatchInfo.get_batch_num() + 1) ** 2
633            ...     index = 0
634            ...     for c in col:
635            ...         img = Image.fromarray(c.astype('uint8')).convert('RGB')
636            ...         img = img.resize((s, s))
637            ...         output[index] = np.array(img)
638            ...         index += 1
639            ...     return (output,)
640            >>> dataset = dataset.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)
641            >>>
642            >>> # 3) Create a dataset where its batch size is dynamic
643            >>> # Define a callable batch size function and let batch size increase 1 each time.
644            >>> def add_one(BatchInfo):
645            ...     return BatchInfo.get_batch_num() + 1
646            >>> dataset = dataset.batch(batch_size=add_one, drop_remainder=True)
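            >>>
            >>> # 4) A sketch of per_batch_map combined with Python multiprocessing; the worker
            >>> # count and max_rowsize below are illustrative values, not recommendations.
            >>> def scale_batch(col, BatchInfo):
            ...     return ([c * 2 for c in col],)
            >>> dataset = ds.NumpySlicesDataset([[1], [2], [3], [4]], column_names=["data"])
            >>> dataset = dataset.batch(batch_size=2, input_columns=["data"], per_batch_map=scale_batch,
            ...                         num_parallel_workers=2, python_multiprocessing=True, max_rowsize=16)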
647        """
648        return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, **kwargs)
649
650    @check_padded_batch
651    def padded_batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, pad_info=None):
652        """
653        Combine batch_size consecutive rows into one batch, applying pad_info to the samples first.
654
655        Refer to the following figure for the execution process:
656
657        .. image:: padded_batch_en.png
658
659        Note:
660            The order of repeat and padded_batch operations affects the number of batches.
661            It is recommended to apply the repeat operation after the padded_batch operation.
662
663        Args:
664            batch_size (Union[int, Callable]): The number of rows each batch is created with. An
665                int or callable object which takes exactly 1 parameter, BatchInfo.
666            drop_remainder (bool, optional): Determines whether or not to drop the last block
667                whose data row number is less than batch size. Default: ``False``. If ``True``, and if there
668                are less than batch_size rows available to make the last batch, then those rows will
669                be dropped and not propagated to the child node.
670            num_parallel_workers (int, optional): Number of workers(threads) to process the dataset in parallel.
671                Default: ``None``.
672            pad_info (dict, optional): The pad information about how to batch each column. The key
673                corresponds to the column name, and the value must be a tuple of 2 elements.
674                The first element corresponds to the shape to pad to, and the second
675                element corresponds to the value to pad with. If a column is not
676                specified, then that column will be padded to the longest in the current
677                batch, and 0 will be used as the padding value. If ``pad_info={"col1": ([224, 224], 0)}``,
678                expand the data column named ``col1`` to shape (224, 224), and fill in the missing values with 0.
679                If ``pad_info={}``, all samples in the batch will be filled to the shape with the largest sample
680                in the current batch. If ``pad_info={"col1": (None, 100)}``, all samples in the batch will be filled
681                to the shape with the largest sample in the current batch, and fill in the missing values with 100.
682                If no padding is wanted, set `pad_info` to ``None``. Default: ``None``.
683
684        Returns:
685            Dataset, a new dataset with the above operation applied.
686
687        Examples:
688            >>> # 1) Pad every sample to the largest sample's shape and batch the samples
689            >>> import mindspore.dataset as ds
690            >>> dataset = ds.NumpySlicesDataset([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]], "column1")
691            >>> dataset = dataset.padded_batch(2, True, pad_info={})
692            >>>
693            >>> # 2) Create a dataset where every 3 rows are combined into a batch
694            >>> # and drops the last incomplete batch if there is one.
695            >>> dataset = ds.NumpySlicesDataset([i for i in range(10)], "column1")
696            >>> dataset = dataset.padded_batch(3, True)
697            >>>
698            >>> # 3) Create a dataset where its batch size is dynamic
699            >>> # Define a callable batch size function and let batch size increase 1 each time.
700            >>> def add_one(BatchInfo):
701            ...     return BatchInfo.get_batch_num() + 1
702            >>> dataset = dataset.padded_batch(batch_size=add_one, drop_remainder=True)
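            >>>
            >>> # 4) A sketch of padding to a fixed shape: pad "column1" to length 5 and fill
            >>> # the missing values with -1 (the target shape and fill value are illustrative).
            >>> dataset = ds.NumpySlicesDataset([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]], "column1")
            >>> dataset = dataset.padded_batch(2, True, pad_info={"column1": ([5], -1)})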
703        """
704        return PaddedBatchDataset(self, batch_size, drop_remainder, num_parallel_workers, pad_info)
705
706    @check_sync_wait
707    def sync_wait(self, condition_name, num_batch=1, callback=None):
708        """
709        Add a blocking condition to the input Dataset and a synchronize action will be applied.
710
711        Args:
712            condition_name (str): The condition name that is used to toggle sending next row.
713            num_batch (int): the number of batches without blocking at the start of each epoch.
714                Default: ``1``.
715            callback (function): The callback function that will be invoked when sync_update is called.
716                Default: ``None``.
717
718        Returns:
719            Dataset, a new dataset with the above operation applied.
720
721        Raises:
722            RuntimeError: If condition name already exists.
723
724        Examples:
725            >>> import mindspore.dataset as ds
726            >>> import numpy as np
727            >>> def gen():
728            ...     for i in range(100):
729            ...         yield (np.array(i),)
730            >>>
731            >>> class Augment:
732            ...     def __init__(self, loss):
733            ...         self.loss = loss
734            ...
735            ...     def preprocess(self, input_):
736            ...         return input_
737            ...
738            ...     def update(self, data):
739            ...         self.loss = data["loss"]
740            >>>
741            >>> batch_size = 4
742            >>> dataset = ds.GeneratorDataset(gen, column_names=["input"])
743            >>>
744            >>> aug = Augment(0)
745            >>> dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
746            >>> dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
747            >>> dataset = dataset.batch(batch_size)
748            >>> count = 0
749            >>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
750            ...     assert data["input"][0] == count
751            ...     count += batch_size
752            ...     data = {"loss": count}
753            ...     dataset.sync_update(condition_name="policy", data=data)
754        """
755        return SyncWaitDataset(self, condition_name, num_batch, callback)
756
757    @check_shuffle
758    def shuffle(self, buffer_size):
759        """
760        Shuffle the dataset by creating a cache with the size of `buffer_size` .
761
762        1. Make a shuffle buffer that contains the first `buffer_size` rows.
763        2. Randomly select an element from the shuffle buffer to be the next row
764           propagated to the child node.
765        3. Get the next row (if any) from the parent node and put it in the shuffle buffer.
766        4. Repeat steps 2 and 3 until there are no more rows left in the shuffle buffer.
767
768        A random seed can be provided to be used on the first epoch via `dataset.config.set_seed` . In every subsequent
769        epoch, the seed is changed to a new, randomly generated value.
770
771        Args:
772            buffer_size (int): The size of the buffer (must be larger than 1) for
773                shuffling. Setting `buffer_size` equal to the number of rows in the entire
774                dataset will result in a global shuffle.
775
776        Returns:
777            Dataset, a new dataset with the above operation applied.
778
779        Raises:
780            RuntimeError: If a sync operation exists before the shuffle operation.
781
782        Examples:
783            >>> import mindspore.dataset as ds
784            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
785            >>>
786            >>> # Optionally set the seed for fixed randomness
787            >>> ds.config.set_seed(58)
788            >>>
789            >>> # Create a shuffled dataset using a shuffle buffer of size 4
790            >>> dataset = dataset.shuffle(4)
791        """
792        return ShuffleDataset(self, buffer_size)
793
794    def flat_map(self, func):
795        """
796        Map `func` to each row in dataset and flatten the result.
797
798        Args:
799            func (function): A function that must take one `numpy.ndarray` as an argument and
800                return a `Dataset` .
801
802        Returns:
803            Dataset, a new dataset with the above operation applied.
804
805        Examples:
806            >>> import mindspore.dataset as ds
807            >>> # 1) flat_map on one column dataset
808            >>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]], shuffle=False)
809            >>>
810            >>> def repeat(array):
811            ...     # create a NumpySlicesDataset with the array
812            ...     data = ds.NumpySlicesDataset(array, shuffle=False)
813            ...     # repeat the dataset twice
814            ...     data = data.repeat(2)
815            ...     return data
816            >>>
817            >>> dataset = dataset.flat_map(repeat)
818            >>> # [0, 1, 0, 1, 2, 3, 2, 3]
819            >>>
820            >>> # 2) flat_map on multi column dataset
821            >>> dataset = ds.NumpySlicesDataset(([[0, 1], [2, 3]], [[0, -1], [-2, -3]]), shuffle=False)
822            >>>
823            >>> def plus_and_minus(col1, col2):
824            ...     # apply different methods on columns
825            ...     data = ds.NumpySlicesDataset((col1 + 1, col2 - 1), shuffle=False)
826            ...     return data
827            >>>
828            >>> dataset = dataset.flat_map(plus_and_minus)
829            >>> # ([1, 2, 3, 4], [-1, -2, -3, -4])
830
831        Raises:
832            TypeError: If `func` is not a function.
833            TypeError: If `func` doesn't return a Dataset.
834        """
835        dataset = None
836        if not hasattr(func, '__call__'):
837            logger.critical("func must be a function.")
838            raise TypeError("func must be a function.")
839
840        for row_data in self.create_tuple_iterator(num_epochs=1, output_numpy=True):
841            if dataset is None:
842                dataset = func(*row_data)
843            else:
844                dataset += func(*row_data)
845
846        if not isinstance(dataset, Dataset):
847            logger.critical("flat_map must return a Dataset object.")
848            raise TypeError("flat_map must return a Dataset object.")
849        return dataset
850
851    @check_map
852    def map(self, operations, input_columns=None, output_columns=None, column_order=None,
853            num_parallel_workers=None, **kwargs):
854        """
855        Apply each operation in operations to this dataset.
856
857        Each operation will be passed one or more columns from the dataset as input, and one or
858        more columns will be outputted. The first operation will be passed the columns specified
859        in input_columns as input. If there is more than one operation in operations, the outputted
860        columns of the previous operation are used as the input columns for the next operation.
861
862        The columns outputted by the very last operation will be assigned names specified by
863        `output_columns` , and if not specified, the column name of output column is same as that of `input_columns` .
864
865        - If you use transformations (
866          `vision transform <https://mindspore.cn/docs/en/master/api_python/mindspore.\
867          dataset.transforms.html#module-mindspore.dataset.vision>`_ ,
868          `nlp transform <https://mindspore.cn/docs/en/master/api_python/mindspore.\
869          dataset.transforms.html#module-mindspore.dataset.text>`_ ,
870          `audio transform <https://mindspore.cn/docs/en/master/api_python/mindspore.\
871          dataset.transforms.html#module-mindspore.dataset.audio>`_ )
872          provided by mindspore dataset, please use the following parameters:
873
874          .. image:: map_parameter_en.png
875
876        - If you use user-defined transform as PyFunc (Python Func), please use the following parameters:
877
878          .. image:: map_parameter_pyfunc_en.png
879
880        Args:
881            operations (Union[list[TensorOperation], list[functions]]): List of operations to be
882                applied on the dataset. Operations are applied in the order they appear in this list.
883            input_columns (Union[str, list[str]], optional): List of the names of the columns that will be passed to
884                the first operation as input. The size of this list must match the number of
885                input columns expected by the first operation. Default: ``None``, the first
886                operation will be passed however many columns that are required, starting from
887                the first column.
888            output_columns (Union[str, list[str]], optional): List of names assigned to the columns outputted by
889                the last operation. This parameter is mandatory if len(input_columns) !=
890                len(output_columns). The size of this list must match the number of output
891                columns of the last operation. Default: ``None``, output columns will have the same
892                name as the input columns, i.e., the columns will be replaced.
893            num_parallel_workers (int, optional): Number of threads used to process the dataset in
894                parallel. Default: ``None``, the value from the configuration will be used.
895            **kwargs:
896
897                - python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes.
898                  This option could be beneficial if the Python operation is computationally heavy. Default: ``False``.
899
900                - max_rowsize (Union[int, list[int]], optional): Maximum size of row in MB that is used for shared
901                  memory allocation to copy data between processes, the total occupied shared memory will increase as
902                  ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set
903                  to -1, shared memory will be dynamically allocated with the actual size of data. This is only used if
904                  ``python_multiprocessing`` is set to True. If it is an int value, it represents
905                  ``input_columns`` and ``output_columns`` use this value as the unit to create shared memory.
906                  If it is a list, the first element represents the ``input_columns`` use this value as the unit to
907                  create shared memory, and the second element represents ``output_columns`` use this value as the unit
908                  to create shared memory. Default: 16.
909
910                - cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
911                  Default: ``None``, which means no cache is used.
912
913                - callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called.
914                  Default: ``None``.
915
916                - offload (bool, optional): Flag to indicate whether offload is used. Default: ``None``.
917
918        Note:
919            - Input `operations` accepts TensorOperations defined in mindspore.dataset part, plus user-defined
920              Python functions (PyFuncs).
921            - Do not add network computing operators from mindspore.nn and mindspore.ops or others into this
922              `operations` .
923
924        Returns:
925            Dataset, a new dataset with the above operation applied.
926
927        Examples:
928            >>> import mindspore.dataset as ds
929            >>> import mindspore.dataset.vision as vision
930            >>> # dataset is an instance of Dataset which has 2 columns, "image" and "label".
931            >>> # image is of type bytes type which can be decoded to RGB
932            >>> # label is of type int32
933            >>> cifar10_dataset_dir = "/path/to/cifar10_dataset_directory"
934            >>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir)
935            >>>
936            >>> # Define two operations, where each operation accepts 1 input column and outputs 1 column.
937            >>> decode_op = vision.Decode(to_pil=False)
938            >>> random_jitter_op = vision.RandomColorAdjust(brightness=(0.8, 0.8), contrast=(1, 1),
939            ...                                             saturation=(1, 1), hue=(0, 0))
940            >>>
941            >>> # 1) Simple map example.
942            >>>
943            >>> # Apply decode_op on column "image".
944            >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"])
945            >>>
946            >>> # Decode and rename column "image" to "decoded_image".
947            >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"], output_columns=["decoded_image"])
948            >>>
949            >>> # A simple example for user defined python function transform.
950            >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
951            >>> dataset = dataset.map(operations=[(lambda x: x - 1)], input_columns=["data"])
952            >>>
953            >>> # 2) Map example with more than one operation.
954            >>>
955            >>> # Create a dataset where the images are decoded, then randomly color jittered.
956            >>> # decode_op takes column "image" as input and outputs one column. The column
957            >>> # outputted by decode_op is passed as input to random_jitter_op.
958            >>> # random_jitter_op will output one column. Column "image" will be replaced by
959            >>> # the column outputted by random_jitter_op (the very last operation). All other
960            >>> # columns are unchanged.
961            >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"])
962            >>>
963            >>> # Rename the column outputted by random_jitter_op to "image_mapped".
964            >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"],
965            ...                       output_columns=["image_mapped"])
966            >>>
967            >>> # Map with multiple operations using pyfunc and rename column's name
968            >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
969            >>> dataset = dataset.map(operations=[(lambda x: x * x), (lambda x: x - 1)], input_columns=["data"],
970            ...                                   output_columns=["data_mapped"])
971            >>>
972            >>> # 3) Example where number of input columns is not equal to number of output columns.
973            >>>
974            >>> # operations[0] is a lambda that takes 2 columns as input and outputs 3 columns.
975            >>> # operations[1] is a lambda that takes 3 columns as input and outputs 1 column.
976            >>> # operations[2] is a lambda that takes 1 column as input and outputs 4 columns.
977            >>> #
978            >>> # Note: The number of output columns of operation[i] must equal the number of
979            >>> # input columns of operation[i+1]. Otherwise, this map call will also result
980            >>> # in an error.
981            >>> operations = [(lambda x, y: (x, x + y, x + y + 1)),
982            ...               (lambda x, y, z: x * y * z),
983            ...               (lambda x: (x % 2, x % 3, x % 5, x % 7))]
984            >>> dataset = ds.NumpySlicesDataset(data=([[0, 1, 2]], [[3, 4, 5]]), column_names=["x", "y"])
985            >>> dataset = dataset.map(operations, input_columns=["x", "y"],
986            ...                       output_columns=["mod2", "mod3", "mod5", "mod7"])
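            >>>
            >>> # 4) A sketch of running a Python transform with multiprocessing; the worker
            >>> # count below is illustrative, not a recommendation.
            >>> def plus_one(x):
            ...     return x + 1
            >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
            >>> dataset = dataset.map(operations=[plus_one], input_columns=["data"],
            ...                       num_parallel_workers=2, python_multiprocessing=True)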
987        """
988        if hasattr(self, 'operator_mixed') and getattr(self, 'operator_mixed') is True:
989            num_parallel_workers = 1
990            logger.warning(
991                "Input 'operations' of 'map' includes network computing operators from mindspore.nn, mindspore.ops, "
992                "mindspore.numpy, etc., which do not support multi-thread compiling. It is recommended to replace "
993                "them with Python-implemented operators such as numpy. 'num_parallel_workers' is reduced to 1 here.")
994
995        return MapDataset(self, operations, input_columns, output_columns, num_parallel_workers, **kwargs)
996
997    @check_filter
998    def filter(self, predicate, input_columns=None, num_parallel_workers=None):
999        """
1000        Filter dataset by predicate.
1001
1002        Args:
1003            predicate (callable): Python callable which returns a boolean value. Rows for which it returns False are filtered out.
1004            input_columns (Union[str, list[str]], optional): List of names of the input columns. If not provided
1005                or provided with ``None``, the predicate will be applied on all columns in the dataset.
1006                Default: ``None``.
1007            num_parallel_workers (int, optional): Number of workers to process the dataset
1008                in parallel. Default: ``None``.
1009
1010        Returns:
1011            Dataset, a new dataset with the above operation applied.
1012
1013        Examples:
1014            >>> # generator data(0 ~ 19)
1015            >>> # filter the data that greater than or equal to 11
1016            >>> import mindspore.dataset as ds
1017            >>> dataset = ds.GeneratorDataset([i for i in range(20)], "data")
1018            >>> dataset = dataset.filter(predicate=lambda data: data < 11, input_columns = ["data"])
1019        """
1020        return FilterDataset(self, predicate, input_columns, num_parallel_workers)
1021
1022    @check_repeat
1023    def repeat(self, count=None):
1024        """
1025        Repeat this dataset `count` times. Repeat infinitely if the `count` is ``None`` or ``-1``.
1026
1027        Note:
1028            The order of repeat and batch operations affects the number of batches. It is recommended that
1029            the repeat operation is used after the batch operation.
1030
1031        Args:
1032            count (int): Number of times the dataset is going to be repeated. Default: ``None``.
1033
1034        Returns:
1035            Dataset, a new dataset with the above operation applied.
1036
1037        Examples:
1038            >>> import mindspore.dataset as ds
1039            >>>
1040            >>> # Create a dataset with 10 elements
1041            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
1042            >>> ori_size = dataset.get_dataset_size()
1043            >>>
1044            >>> # Repeat the dataset 50 times.
1045            >>> dataset = dataset.repeat(50)
1046            >>> repeated_size = dataset.get_dataset_size()
1047            >>> print("ori_size", ori_size, ", repeated_size", repeated_size)
1048            ori_size 10 , repeated_size 500
1049            >>>
1050            >>> # Since the original dataset size is less than batch_size, no data is returned
1051            >>> dataset1 = ds.GeneratorDataset([i for i in range(10)], "column1")
1052            >>> dataset1 = dataset1.batch(batch_size=20, drop_remainder=True)
1053            >>> dataset1 = dataset1.repeat(6)
1054            >>>
1055            >>> # Repeat the original dataset to 60 elements, thus 3 batches are returned
1056            >>> dataset2 = ds.GeneratorDataset([i for i in range(10)], "column1")
1057            >>> dataset2 = dataset2.repeat(6)
1058            >>> dataset2 = dataset2.batch(batch_size=20, drop_remainder=True)
1059            >>> print("dataset1 size", dataset1.get_dataset_size(), ", dataset2 size", dataset2.get_dataset_size())
1060            dataset1 size 0 , dataset2 size 3
1061        """
1062        return RepeatDataset(self, count)
1063
1064    @check_skip
1065    def skip(self, count):
1066        """
1067        Skip the first N elements of this dataset.
1068
1069        Args:
1070            count (int): Number of elements in the dataset to be skipped.
1071
1072        Returns:
1073            Dataset, a new dataset with the above operation applied.
1074
1075        Examples:
1076            >>> import mindspore.dataset as ds
1077            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
1078            >>> # Skip first 3 elements of dataset and retain 7 elements.
1079            >>> dataset = dataset.skip(3)
1080        """
1081        return SkipDataset(self, count)
1082
1083    @check_take
1084    def take(self, count=-1):
1085        """
1086        Take the first specified number of samples from the dataset.
1087
1088        Args:
1089            count (int, optional): The desired number of samples to take. If the value exceeds
1090                the total number of samples in the dataset, all data will be returned.
1091                Default: ``-1`` , will return all data.
1092
1093        Note:
1094            When there are operations that will change the number of samples of the dataset in
1095            the data pipeline, the location of the `take` operation can change its effect.
1096            For example, `batch` operation will combine the successive samples of the specified
1097            `batch_size` into 1 sample, so `.batch(batch_size).take(1)` will be equivalent to
1098            `.take(batch_size).batch(batch_size)`.
1099
1100        Returns:
1101            Dataset, a new dataset with the above operation applied.
1102
1103        Examples:
1104            >>> import mindspore.dataset as ds
1105            >>> mnist_dataset_dir = "/path/to/mnist_dataset_directory"
1106            >>> dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir)
1107            >>> # Take 50 samples from MNIST dataset.
1108            >>> dataset = dataset.take(50)
1109        """
1110        return TakeDataset(self, count)
1111
1112    def _get_absolute_split_sizes(self, sizes):
1113        """
1114        Internal method called by split to calculate absolute split sizes and to
1115        do some error checking after calculating absolute split sizes.
1116
1117        Returns:
1118            int, absolute split sizes of the dataset.
1119        """
1120        # Call get_dataset_size here and check input here because
1121        # don't want to call this once in check_split and another time in
1122        # here again
1123        dataset_size = self.get_dataset_size()
1124
1125        if dataset_size is None or dataset_size <= 0:
1126            raise RuntimeError("dataset_size is unknown, unable to split.")
1127
1128        if not isinstance(sizes, list):
1129            raise RuntimeError("sizes must be a list.")
1130
1131        all_int = all(isinstance(item, int) for item in sizes)
1132        if all_int:
1133            sizes_sum = sum(sizes)
1134            if sizes_sum != dataset_size:
1135                raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}."
1136                                   .format(sizes_sum, dataset_size))
1137            return sizes
1138
1139        absolute_sizes = []
1140        for item in sizes:
1141            absolute_size = int(round(item * dataset_size))
1142            if absolute_size == 0:
1143                raise RuntimeError("Split percentage {} is too small.".format(item))
1144            absolute_sizes.append(absolute_size)
1145
1146        absolute_sizes_sum = sum(absolute_sizes)
1147
1148        # if we still need more rows, give them to the first split.
1149        # if we have too many rows, remove the extras from the first split that has
1150        # enough rows.
1151        size_difference = int(dataset_size - absolute_sizes_sum)
1152        if size_difference > 0:
1153            absolute_sizes[0] += size_difference
1154        else:
1155            for i, _ in enumerate(absolute_sizes):
1156                if absolute_sizes[i] + size_difference > 0:
1157                    absolute_sizes[i] += size_difference
1158                    break
1159
1160        if sum(absolute_sizes) != dataset_size:
1161            raise RuntimeError("Sum of calculated split sizes {} is not equal to dataset size {}."
1162                               .format(absolute_sizes_sum, dataset_size))
1163
1164        return absolute_sizes
1165
1166    @check_split
1167    def split(self, sizes, randomize=True):
1168        """
1169        Split the dataset into smaller, non-overlapping datasets.
1170
1171        Args:
1172            sizes (Union[list[int], list[float]]): If a list of integers [s1, s2, …, sn] is
1173                provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
1174                respectively. If the sum of all input sizes does not equal the original dataset size, an
1175                error will be raised.
1176                If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1
1177                and must sum to 1, otherwise an error will be raised. The dataset will be split into n
1178                Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the
1179                original dataset.
1180                If after rounding:
1181
1182                - Any size equals 0, an error will occur.
1183                - The sum of split sizes < K, the difference of K - sigma(round(fi * K)) will be added to the first
1184                  split.
1185                - The sum of split sizes > K, the difference of sigma(round(fi * K)) - K will be removed from the first
1186                  large enough split such that it will have at least 1 row after removing the difference.
1187
1188            randomize (bool, optional): Determines whether or not to split the data randomly. Default: ``True``.
1189                If True, the data will be randomly split. Otherwise, each split will be created with
1190                consecutive rows from the dataset.
1191
1192        Note:
1193            1. Dataset cannot be sharded if split is going to be called.
1194            2. It is strongly recommended not to shuffle the dataset, but to use randomize=True instead.
1195               Shuffling the dataset may not be deterministic, which means the data in each split
1196               will be different in each epoch.
1197
1198        Returns:
1199            Tuple[Dataset], a tuple of new datasets split from the original one.
1200
1201        Raises:
1202            RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
1203            RuntimeError: If `sizes` is list of integers and sum of all elements in sizes does not
1204                equal the dataset size.
1205            RuntimeError: If `sizes` is list of float and there is a split with size 0 after calculations.
1206            RuntimeError: If the dataset is sharded prior to calling split.
1207            ValueError: If `sizes` is list of float and not all floats are between 0 and 1, or if the
1208                floats don't sum to 1.
1209
1210        Examples:
1211            >>> # Split the data into train part and test part.
1212            >>> import mindspore.dataset as ds
1213            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
1214            >>> train_dataset, test_dataset = dataset.split([0.9, 0.1])
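            >>>
            >>> # A further sketch (hypothetical sizes): the same dataset can be split with absolute integer
            >>> # sizes, as long as they sum to the dataset size (here 8 + 2 = 10).
            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
            >>> train_dataset, test_dataset = dataset.split([8, 2])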
1215        """
1216        if self.is_shuffled():
1217            logger.warning("Dataset is shuffled before split.")
1218
1219        if self.is_sharded():
1220            raise RuntimeError("Dataset should not be sharded before split.")
1221
1222        absolute_sizes = self._get_absolute_split_sizes(sizes)
1223        splits = []
1224        rows_to_skip = 0
1225        for size in absolute_sizes:
1226            ds = copy.deepcopy(self)
1227            if randomize:
1228                # We want to shuffle in the same way every epoch before splitting.
1229                # In alter_tree, the shuffle buffer size is at least 10000, so use 10000 here.
1230                ds = ds.shuffle(10000)
1231                ds.reshuffle_each_epoch = False
1232
1233            if rows_to_skip > 0:
1234                ds = ds.skip(rows_to_skip)
1235
1236            ds = ds.take(size)
1237            splits.append(ds)
1238
1239            rows_to_skip += size
1240
1241        return tuple(splits)
1242
1243    @check_zip_dataset
1244    def zip(self, datasets):
1245        """
1246        Zip this dataset with the datasets in the input tuple. The column names in the input datasets must
1247        all be different.
1248
1249        Args:
1250            datasets (Union[Dataset, tuple[Dataset]]): A tuple of datasets or a single class Dataset
1251                to be zipped together with this dataset.
1252
1253        Returns:
1254            Dataset, a new dataset with the above operation applied.
1255
1256        Raises:
1257            TypeError: The parameter is not dataset object or tuple of dataset objects.
1258
1259        Examples:
1260            >>> # Create a dataset which is the combination of dataset_1 and dataset_2
1261            >>> import mindspore.dataset as ds
1262            >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1")
1263            >>> dataset_2 = ds.GeneratorDataset([1, 2, 3], "column2")
1264            >>> dataset = dataset_1.zip(dataset_2)
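            >>>
            >>> # A sketch of the expected result: the zipped dataset carries the columns of both inputs,
            >>> # so get_col_names() is expected to return ['column1', 'column2'].
            >>> col_names = dataset.get_col_names()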
1265        """
1266        if isinstance(datasets, tuple):
1267            datasets = (self, *datasets)
1268        elif isinstance(datasets, Dataset):
1269            datasets = (self, datasets)
1270        else:
1271            raise TypeError("Invalid datasets, expected Dataset object or tuple of Dataset, but got %s!" % datasets)
1272        return ZipDataset(datasets)
1273
1274    @check_concat
1275    def concat(self, datasets):
1276        """
1277        Concatenate the dataset objects in the input list.
1278        Performing "+" operation on dataset objects can achieve the same effect.
1279
1280        For a dataset concatenated from many other dataset objects, it returns the data in the order of
1281        datasets passed in. If you want to change the data order (such as random selection from each dataset
1282        instead of reading in sequence), apply the `use_sampler` method on the concatenated dataset object.
1283        Currently `use_sampler` supports `dataset.DistributedSampler` for sharding selection from each dataset
1284        or `dataset.RandomSampler` for random selection from each dataset, see examples below.
1285
1286        Note:
1287            The column names, and the rank and type of the column data, must be the same across the input datasets.
1288
1289        Args:
1290            datasets (Union[list, Dataset]): A list of datasets or a single class Dataset
1291                to be concatenated together with this dataset.
1292
1293        Returns:
1294            Dataset, a new dataset with the above operation applied.
1295
1296        Examples:
1297            >>> import mindspore.dataset as ds
1298            >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1", shuffle=False)
1299            >>> dataset_2 = ds.GeneratorDataset([4, 5, 6], "column1", shuffle=False)
1300            >>>
1301            >>> # Create a dataset by concatenating dataset_1 and dataset_2 with "+" operator
1302            >>> dataset = dataset_1 + dataset_2
1303            >>> # Create a dataset by concatenating dataset_1 and dataset_2 with concat operation
1304            >>> dataset = dataset_1.concat(dataset_2)
1305            >>>
1306            >>> # Check the data order of dataset
1307            >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1", shuffle=False)
1308            >>> dataset_2 = ds.GeneratorDataset([4, 5, 6], "column1", shuffle=False)
1309            >>> dataset = dataset_1 + dataset_2
1310            >>> result = list(dataset)
1311            >>> # [[Tensor(shape=[], dtype=Int64, value= 1)], [Tensor(shape=[], dtype=Int64, value= 2)],
1312            >>> #  [Tensor(shape=[], dtype=Int64, value= 3)], [Tensor(shape=[], dtype=Int64, value= 4)],
1313            >>> #  [Tensor(shape=[], dtype=Int64, value= 5)], [Tensor(shape=[], dtype=Int64, value= 6)]]
1314            >>>
1315            >>> # Change the data order of concatenated dataset with sharding selection
1316            >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1", shuffle=False)
1317            >>> dataset_2 = ds.GeneratorDataset([4, 5, 6], "column1", shuffle=False)
1318            >>> dataset = dataset_1.concat(dataset_2)
1319            >>> dataset.use_sampler(ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False))
1320            >>> result = list(dataset)
1321            >>> # [[Tensor(shape=[], dtype=Int64, value= 2)], [Tensor(shape=[], dtype=Int64, value= 4)],
1322            >>> #  [Tensor(shape=[], dtype=Int64, value= 6)]]
1323            >>>
1324            >>> # Change the data order of concatenated dataset with random selection
1325            >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1", shuffle=False)
1326            >>> dataset_2 = ds.GeneratorDataset([4, 5, 6], "column1", shuffle=False)
1327            >>> dataset = dataset_1.concat(dataset_2)
1328            >>> dataset.use_sampler(ds.RandomSampler())
1329            >>> result = list(dataset)
1330            >>> # [[Tensor(shape=[], dtype=Int64, value= 1)], [Tensor(shape=[], dtype=Int64, value= 4)],
1331            >>> #  [Tensor(shape=[], dtype=Int64, value= 2)], [Tensor(shape=[], dtype=Int64, value= 5)],
1332            >>> #  [Tensor(shape=[], dtype=Int64, value= 6)], [Tensor(shape=[], dtype=Int64, value= 3)]]
1333        """
1334        if isinstance(datasets, Dataset):
1335            datasets = [self] + [datasets]
1336        elif isinstance(datasets, list):
1337            datasets = [self] + datasets
1338        else:
1339            raise TypeError("Invalid datasets, expected Dataset object or list of Dataset, but got %s!" % datasets)
1340        return ConcatDataset(datasets)
1341
1342    @check_rename
1343    def rename(self, input_columns, output_columns):
1344        """
1345        Rename the columns in the input dataset.
1346
1347        Args:
1348            input_columns (Union[str, list[str]]): List of names of the input columns.
1349            output_columns (Union[str, list[str]]): List of names of the output columns.
1350
1351        Returns:
1352            Dataset, a new dataset with the above operation applied.
1353
1354        Examples:
1355            >>> import mindspore.dataset as ds
1356            >>> input_columns = ["input_col1", "input_col2", "input_col3"]
1357            >>> output_columns = ["output_col1", "output_col2", "output_col3"]
1358            >>>
1359            >>> # Create a dataset with 3 columns
1360            >>> dataset = ds.GeneratorDataset([(1, 2, 3), (3, 4, 5), (5, 6, 7)], column_names=input_columns)
1361            >>>
1362            >>> # Rename "input_col1" to "output_col1", "input_col2" to "output_col2", "input_col3" to "output_col3"
1363            >>> dataset = dataset.rename(input_columns=input_columns, output_columns=output_columns)
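            >>>
            >>> # A sketch of the expected result: the data is now exposed under the new column names,
            >>> # so get_col_names() is expected to return the names in output_columns.
            >>> col_names = dataset.get_col_names()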
1364        """
1365
1366        return RenameDataset(self, input_columns, output_columns)
1367
1368    @check_project
1369    def project(self, columns):
1370        """
1371        The specified columns will be selected from the dataset and passed into
1372        the pipeline in the order specified. The other columns are discarded.
1373
1374        Args:
1375            columns(Union[str, list[str]]): List of names of the columns to project.
1376
1377        Returns:
1378            Dataset, a new dataset with the above operation applied.
1379
1380        Examples:
1381            >>> import mindspore.dataset as ds
1382            >>> # Create a dataset with 3 columns
1383            >>> input_columns = ["column1", "column2", "column3"]
1384            >>> dataset = ds.GeneratorDataset([(1, 2, 3), (3, 4, 5), (5, 6, 7)], column_names=input_columns)
1385            >>>
1386            >>> columns_to_project = ["column3", "column1", "column2"]
1387            >>> # Keep only column3, column1 and column2, in that order, regardless of the original column order.
1388            >>> dataset = dataset.project(columns=columns_to_project)
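            >>>
            >>> # A sketch of the expected result: only the projected columns remain, in the requested order,
            >>> # so get_col_names() is expected to return ['column3', 'column1', 'column2'].
            >>> col_names = dataset.get_col_names()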
1389        """
1390
1391        return ProjectDataset(self, columns)
1392
1393    def apply(self, apply_func):
1394        """
1395        Apply a function to this dataset.
1396
1397        Args:
1398            apply_func (function): A function that must take one `Dataset` as an argument and
1399                                   return a preprocessed `Dataset` .
1400
1401        Returns:
1402            Dataset, a new dataset with the above operation applied.
1403
1404        Examples:
1405            >>> import mindspore.dataset as ds
1406            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
1407            >>>
1408            >>> # Declare an apply_func function which returns a Dataset object
1409            >>> def apply_func(data):
1410            ...     data = data.batch(2)
1411            ...     return data
1412            >>>
1413            >>> # Use apply to call apply_func
1414            >>> dataset = dataset.apply(apply_func)
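            >>>
            >>> # Since apply simply invokes apply_func on the dataset, the call above is equivalent to
            >>> # writing dataset = apply_func(dataset) directly.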
1415
1416        Raises:
1417            TypeError: If apply_func is not a function.
1418            TypeError: If apply_func doesn't return a Dataset.
1419        """
1420
1421        if not hasattr(apply_func, '__call__'):
1422            raise TypeError("apply_func must be a function.")
1423
1424        dataset = apply_func(self)
1425        if not isinstance(dataset, Dataset):
1426            raise TypeError("apply_func must return a dataset.")
1427        return dataset
1428
1429    @check_device_send
1430    def device_que(self, send_epoch_end=True, create_data_info_queue=False, queue_name=""):
1431        """
1432        Return a TransferDataset that transfers data to the device.
1433
1434        Args:
1435            send_epoch_end (bool, optional): Whether to send end of sequence to device or not.
1436                Default: ``True``.
1437            create_data_info_queue (bool, optional): Whether to create queue which stores
1438                types and shapes of data or not. Default: ``False``.
1439            queue_name (str, optional): Name of queue which connects dataset processing and model
1440                computing. Default: ``""``.
1441
1442        Note:
1443            If the device is Ascend, the data columns will be transferred one by one. Each
1444            transfer is limited to 256MB.
1445
1446        Returns:
1447            Dataset, a new dataset with the above operation applied.
1448
1449        Examples:
1450            >>> import mindspore.dataset as ds
1451            >>> import time
1452            >>>
1453            >>> data = ds.TFRecordDataset('/path/to/TF_FILES', '/path/to/TF_SCHEMA_FILE', shuffle=ds.Shuffle.FILES)
1454            >>> data = data.device_que()
1455            >>> data.send()
1456            >>> time.sleep(0.1)
1457            >>> data.stop_send()
1458        """
1459        return TransferDataset(self, send_epoch_end, create_data_info_queue, queue_name)
1460
1461    @check_save
1462    def save(self, file_name, num_files=1, file_type='mindrecord'):
1463        """
1464        Save the data processed by the dataset pipeline in a common dataset format.
1465        Currently only the ``'mindrecord'`` format is supported, and the
1466        :class:`mindspore.dataset.MindDataset` API can be used to read the saved file(s).
1467
1468        Implicit type casting exists when saving data as ``'mindrecord'`` . The transform table shows how to do
1469        type casting.
1470
1471        .. list-table:: Implicit Type Casting when Saving as `mindrecord`
1472           :widths: 25 25 50
1473           :header-rows: 1
1474
1475           * - Type in `dataset`
1476             - Type in `mindrecord`
1477             - Details
1478           * - bool
1479             - int32
1480             - transform to int32
1481           * - int8
1482             - int32
1483             -
1484           * - uint8
1485             - int32
1486             -
1487           * - int16
1488             - int32
1489             -
1490           * - uint16
1491             - int32
1492             -
1493           * - int32
1494             - int32
1495             -
1496           * - uint32
1497             - int64
1498             -
1499           * - int64
1500             - int64
1501             -
1502           * - uint64
1503             - int64
1504             - Value may overflow when cast from uint64 to int64
1505           * - float16
1506             - float32
1507             -
1508           * - float32
1509             - float32
1510             -
1511           * - float64
1512             - float64
1513             -
1514           * - string
1515             - string
1516             - Multi-dimensional string not supported
1517           * - bytes
1518             - bytes
1519             - Multi-dimensional bytes not supported
1520
1521        Note:
1522            1. To save the samples in order, set dataset's `shuffle` to ``False`` and `num_files` to ``1``.
1523            2. Before calling the function, do not use the batch operation, the repeat operation, or data
1524               augmentation operations with random attributes in a map operation.
1525            3. When array dimension is variable, one-dimensional arrays or
1526               multi-dimensional arrays with variable dimension 0 are supported.
1527            4. MindRecord does not support multi-dimensional string or multi-dimensional bytes.
1528
1529        Args:
1530            file_name (str): Path to dataset file.
1531            num_files (int, optional): Number of dataset files. Default: ``1`` .
1532            file_type (str, optional): Dataset format. Default: ``'mindrecord'`` .
1533
1534        Examples:
1535            >>> import mindspore.dataset as ds
1536            >>> import numpy as np
1537            >>>
1538            >>> def generator_1d():
1539            ...     for i in range(10):
1540            ...         yield (np.array([i]),)
1541            >>>
1542            >>> # apply dataset operations
1543            >>> d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)
1544            >>> d1.save('/path/to/save_file')
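            >>>
            >>> # A sketch of reading the saved data back (reusing the hypothetical path from above):
            >>> d2 = ds.MindDataset(dataset_files='/path/to/save_file')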
1545        """
1546        if (_get_enc_key() is not None or _get_hash_mode() is not None) and num_files > 1:
1547            raise RuntimeError("When encode mode or hash check is enabled, " +
1548                               "the automatic sharding function is unavailable.")
1549
1550        ir_tree, api_tree = self.create_ir_tree()
1551
1552        runtime_context = cde.PythonRuntimeContext()
1553        runtime_context.Init()
1554        consumer = cde.PythonSaveToDisk(file_name, num_files, file_type)
1555        consumer.Init(ir_tree)
1556        runtime_context.AssignConsumer(consumer)
1557
1558        consumer.Save()
1559
1560        if _get_hash_mode() is not None:
1561            append_hash_to_file(file_name)
1562            append_hash_to_file(file_name + ".db")
1563
1564        if _get_enc_key() is not None:
1565            encrypt(file_name, _get_enc_key(), _get_enc_mode())
1566            encrypt(file_name + ".db", _get_enc_key(), _get_enc_mode())
1567
1568        _set_dataset_permissions(file_name, num_files)
1569        del api_tree
1570
1571    @check_tuple_iterator
1572    def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
1573        """
1574        Create an iterator over the dataset that yields samples of type list, whose elements are
1575        the data for each column.
1576
1577        Args:
1578            columns (list[str], optional): Specify the output columns and the order.
1579                Default: ``None``, keep all the output columns and their original order.
1580            num_epochs (int, optional): The number of epochs to iterate over the entire dataset.
1581                Default: ``-1`` , the dataset can be iterated indefinitely.
1582            output_numpy (bool, optional): Whether to keep the output data as NumPy ndarray, or
1583                convert it to Tensor. Default: ``False`` .
1584            do_copy (bool, optional): Whether to copy the data when converting output to Tensor,
1585                or reuse the buffer for better performance, only works when `output_numpy` is ``False`` .
1586                Default: ``True`` .
1587
1588        Returns:
1589            Iterator, a dataset iterator that yields samples of type list.
1590
1591        Examples:
1592            >>> import mindspore.dataset as ds
1593            >>>
1594            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "data")
1595            >>> num_epochs = 3
1596            >>> iterator = dataset.create_tuple_iterator(num_epochs=num_epochs)
1597            >>> for epoch in range(num_epochs):
1598            ...     for item in iterator:
1599            ...         # each sample is a list whose elements are the column data
1600            ...         print(type(item))
1601            ...         break
1602            ...     break
1603            <class 'list'>
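            >>>
            >>> # A sketch of restricting and ordering the output columns (the "data" column is defined above):
            >>> iterator = dataset.create_tuple_iterator(columns=["data"], num_epochs=1, output_numpy=True)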
1604        """
1605        if output_numpy is None:
1606            output_numpy = False
1607
1608        if Dataset._noop_mode():
1609            return DummyIterator(self, 'tuple', output_numpy)
1610        return TupleIterator(self, columns, num_epochs, output_numpy, do_copy)
1611
1612    @check_dict_iterator
1613    def create_dict_iterator(self, num_epochs=-1, output_numpy=False, do_copy=True):
1614        """
1615        Create an iterator over the dataset that yields samples of type dict,
1616        where the key is the column name and the value is the data.
1617
1618        Args:
1619            num_epochs (int, optional): The number of epochs to iterate over the entire dataset.
1620                Default: ``-1`` , the dataset can be iterated indefinitely.
1621            output_numpy (bool, optional): Whether to keep the output data as NumPy ndarray, or
1622                convert it to Tensor. Default: ``False`` .
1623            do_copy (bool, optional): Whether to copy the data when converting output to Tensor,
1624                or reuse the buffer for better performance, only works when `output_numpy` is ``False`` .
1625                Default: ``True`` .
1626
1627        Returns:
1628            Iterator, a dataset iterator that yields samples of type dict.
1629
1630        Examples:
1631            >>> import mindspore.dataset as ds
1632            >>>
1633            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "data")
1634            >>> num_epochs = 3
1635            >>> iterator = dataset.create_dict_iterator(num_epochs=num_epochs)
1636            >>> for epoch in range(num_epochs):
1637            ...     for item in iterator:
1638            ...         # output is of type dict
1639            ...         print(type(item))
1640            ...         break
1641            ...     break
1642            <class 'dict'>
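            >>>
            >>> # A sketch of accessing a column by name (the column is named "data" above):
            >>> for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            ...     value = item["data"]
            ...     break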
1643        """
1644        if output_numpy is None:
1645            output_numpy = False
1646
1647        if Dataset._noop_mode():
1648            return DummyIterator(self, 'dict', output_numpy)
1649        return DictIterator(self, num_epochs, output_numpy, do_copy)
1650
1651    def __iter__(self):
1652        """Create an iterator over the dataset."""
1653        return self.create_tuple_iterator(num_epochs=1)
1654
1655    @property
1656    def input_indexs(self):
1657        """
1658        Get the column index, which describes how the data columns map to the network inputs
1659        when using the sink mode.
1660
1661        Returns:
1662            int, tuple of the input index information.
1663
1664        Examples:
1665            >>> import mindspore.dataset as ds
1666            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
1667            >>> # set input_indexs
1668            >>> dataset.input_indexs = 10
1669            >>> print(dataset.input_indexs)
1670            10
1671        """
1672        if self._input_indexs != ():
1673            return self._input_indexs
1674
1675        # find input_indexes of children
1676        children_input_index = [child.input_indexs for child in self.children]
1677
1678        # in case of more than one child, return the first input_indexes
1679        for cix in children_input_index:
1680            if cix != ():
1681                return cix
1682
1683        # if all children's input_indexes are () or the node is a leaf
1684        return self._input_indexs
1685
1686    @input_indexs.setter
1687    def input_indexs(self, value):
1688        self._input_indexs = value
1689
1690    def copy_batch_size(self, value):
1691        self._batch_size = value
1692
1693    def _init_tree_getters(self, getter_mode=True):
1694        """
1695        Get pipeline information.
1696
1697        Args:
1698            getter_mode (bool, optional): Whether to build IR tree in pull mode. Default: ``True``.
1699        """
1700        ir_tree, api_tree = self.create_ir_tree(getter_mode)
1701
1702        runtime_context = cde.PythonRuntimeContext()
1703        runtime_context.Init()
1704        getter = cde.TreeGetters()
1705        getter.Init(ir_tree)
1706        runtime_context.AssignConsumer(getter)
1707        return getter, runtime_context, api_tree
1708
1709    def __init_size_getter(self):
1710        """
1711        Get pipeline information.
1712        """
1713        ir_tree, api_tree = self.create_ir_tree()
1714
1715        runtime_context = cde.PythonRuntimeContext()
1716        runtime_context.Init()
1717        getter = cde.DatasetSizeGetters()
1718        getter.Init(ir_tree)
1719        runtime_context.AssignConsumer(getter)
1720        return getter, runtime_context, api_tree
1721
1722    def get_col_names(self):
1723        """
1724        Return the names of the columns in dataset.
1725
1726        Returns:
1727            list, list of column names in the dataset.
1728
1729        Examples:
1730            >>> import mindspore.dataset as ds
1731            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
1732            >>> col_names = dataset.get_col_names()
1733            >>> print(col_names)
1734            ['column1']
1735
1736        """
1737        if self._col_names is None:
1738            runtime_getter = self._init_tree_getters()
1739            self._col_names = runtime_getter[0].GetColumnNames()
1740
1741        return self._col_names
1742
1743    @check_output_shape
1744    def output_shapes(self, estimate=False):
1745        """
1746        Get the shapes of output data.
1747
1748        Args:
1749            estimate (bool): If `estimate` is ``False`` , will return the shapes of first data row.
1750                Otherwise, will iterate the whole dataset and return the estimated shapes of data row,
1751                where dynamic shape is marked as None (used in dynamic data shapes scenario).
1752                Default: ``False`` .
1753
1754        Returns:
1755            list, list of shapes of each column.
1756
1757        Examples:
1758            >>> import mindspore.dataset as ds
1759            >>> import numpy as np
1760            >>>
1761            >>> def generator1():
1762            ...     for i in range(1, 100):
1763            ...         yield np.ones((16, 83, 83)), np.array([i])
1764            >>>
1765            >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
1766            >>> output_shapes = dataset.output_shapes()
1767            >>> print(output_shapes)
1768            [[16, 83, 83], [1]]
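            >>>
            >>> # A sketch of estimating shapes over the whole dataset, where any dynamic dimension is
            >>> # reported as None (useful when row shapes vary):
            >>> estimated_shapes = dataset.output_shapes(estimate=True)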
1769        """
1770        # cache single shape
1771        if not estimate and self.saved_output_shapes is not None:
1772            return self.saved_output_shapes
1773        # cache estimate shape
1774        if estimate and self.estimated_output_shapes is not None:
1775            return self.estimated_output_shapes
1776
1777        # A hang may occur when a two-level pipeline is used with multiprocessing, so we need to extend the life
1778        # cycle of runtime_context. This hang has only been observed in output_types and output_shapes.
1779        runtime_getter = self._init_tree_getters()
1780        self.runtime_context = runtime_getter[1]
1781        api_tree = runtime_getter[2]
1782        output_shapes = runtime_getter[0].GetOutputShapes(estimate)
1783        del api_tree
1784        # Need to terminate the runtime context to avoid the occasional hang problem for
1785        # Python (with multiprocessing enabled) in sink mode.
1786        self.runtime_context.Terminate()
1787        del self.runtime_context
1788
1789        if estimate:
1790            self.estimated_output_shapes = output_shapes
1791        else:
1792            self.saved_output_shapes = output_shapes
1793        return output_shapes
1794
1795    def output_types(self):
1796        """
1797        Get the types of output data.
1798
1799        Returns:
1800            list, list of data types.
1801
1802        Examples:
1803            >>> import mindspore.dataset as ds
1804            >>> import numpy as np
1805            >>>
1806            >>> def generator1():
1807            ...     for i in range(1, 100):
1808            ...         yield np.ones((16, 83, 83)).astype(np.float32), np.array([i]).astype(np.int32)
1809            >>>
1810            >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
1811            >>> output_types = dataset.output_types()
1812            >>> print(output_types)
1813            [dtype('float32'), dtype('int32')]
1814        """
1815        if self.saved_output_types is None:
1816            runtime_getter = self._init_tree_getters()
1817            # A hang may occur when a two-level pipeline is used with multiprocessing, so we need to extend
1818            # the life cycle of runtime_context. This hang has only been observed in output_types and output_shapes.
1819            self.runtime_context = runtime_getter[1]
1820            api_tree = runtime_getter[2]
1821            self.saved_output_types = runtime_getter[0].GetOutputTypes()
1822            del api_tree
1823            # Need to terminate the runtime context to avoid the occasional hang problem for
1824            # Python (with multiprocessing enabled) in sink mode.
1825            self.runtime_context.Terminate()
1826            del self.runtime_context
1827        return self.saved_output_types
1828
1829    def get_dataset_size(self):
1830        """
1831        Return the number of batches in an epoch.
1832
1833        Returns:
1834            int, number of batches.
1835
1836        Examples:
1837            >>> import mindspore.dataset as ds
1838            >>> import numpy as np
1839            >>>
1840            >>> # A generator return 66 samples
1841            >>> def generator1():
1842            ...     for i in range(66):
1843            ...         yield np.ones((16, 83, 83)), np.array([i])
1844            >>>
1845            >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
1846            >>> dataset_size = dataset.get_dataset_size()
1847            >>> print(dataset_size)
1848            66
1849        """
1850        if self.dataset_size is None:
1851            runtime_getter = self.__init_size_getter()
1852            self.dataset_size = runtime_getter[0].GetDatasetSize(False)
1853            if self.dataset_size == 0:
1854                logger.warning("Got 0 samples from dataset pipeline, check if all data was dropped or loading failed.")
1855
1856        return self.dataset_size
1857
1858    def num_classes(self):
1859        """
1860        Get the number of classes in a dataset.
1861
1862        Returns:
1863            int, number of classes.
1864
1865        Examples:
1866            >>> import mindspore.dataset as ds
1867            >>> # Read image files
1868            >>> image_folder_dataset_dir = "/path/to/image_folder_dataset_directory"
1869            >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir)
1870            >>> # Check how many classes exist in image folder
1871            >>> num_classes = dataset.num_classes()
1872        """
1873        if self._num_classes is None:
1874            runtime_getter = self._init_tree_getters()
1875            self._num_classes = runtime_getter[0].GetNumClasses()
1876
1877        if self._num_classes == -1:
1878            return None
1879        return self._num_classes
1880
1881    def get_sync_notifiers(self):
1882        if self.children:
1883            return self.children[0].get_sync_notifiers()
1884        return {}
1885
1886    def disable_sync(self):
1887        if self.children:
1888            return self.children[0].disable_sync()
1889        return {}
1890
1891    def is_sync(self):
1892        if self.children:
1893            return self.children[0].is_sync()
1894        return False
1895
1896    def sync_update(self, condition_name, num_batch=None, data=None):
1897        """
1898        Release a blocking condition and trigger the callback with the given data.
1899
1900        Args:
1901            condition_name (str): The condition name that is used to toggle sending next row.
1902            num_batch (Union[int, None]): The number of batches (rows) that are released.
1903                When `num_batch` is ``None``, it will default to the number specified by the
1904                `sync_wait` operation. Default: ``None``.
1905            data (Any): The data passed to the callback, user defined. Default: ``None``.
1906
1907        Examples:
1908            >>> import numpy as np
1909            >>> import mindspore.dataset as ds
1910            >>>
1911            >>> def gen():
1912            ...     for i in range(100):
1913            ...         yield (np.array(i),)
1914            >>>
1915            >>> class Augment:
1916            ...     def __init__(self, loss):
1917            ...         self.loss = loss
1918            ...
1919            ...     def preprocess(self, input_):
1920            ...         return input_
1921            ...
1922            ...     def update(self, data):
1923            ...         self.loss = data["loss"]
1924            >>>
1925            >>> batch_size = 10
1926            >>> dataset = ds.GeneratorDataset(gen, column_names=["input"])
1927            >>> aug = Augment(0)
1928            >>> dataset = dataset.sync_wait(condition_name='', num_batch=1)
1929            >>> dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
1930            >>> dataset = dataset.batch(batch_size)
1931            >>>
1932            >>> count = 0
1933            >>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
1934            ...     count += 1
1935            ...     data = {"loss": count}
1936            ...     dataset.sync_update(condition_name="", data=data)
1937        """
1938        if (not isinstance(num_batch, int) and num_batch is not None) or \
1939                (isinstance(num_batch, int) and num_batch <= 0):
1940            # throwing exception, disable all sync_wait in pipeline
1941            self.disable_sync()
1942            raise RuntimeError("sync_update batch size must be a positive integer, but got: {}.".format(num_batch))
1943        notifiers_dict = self.get_sync_notifiers()
1944        if not isinstance(condition_name, str):
1945            raise TypeError("Argument condition_name with value {} is not of type str, but got {}."
1946                            .format(condition_name, type(condition_name)))
1947        if condition_name not in notifiers_dict:
1948            # throwing exception, disable all sync_wait in pipeline
1949            self.disable_sync()
1950            raise RuntimeError("Condition name '{}' not found.".format(condition_name))
1951        if num_batch is not None:
1952            num_batch *= self.get_batch_size()
1953        notifiers_dict[condition_name](num_batch, data)
1954
1955    def get_batch_size(self):
1956        """
1957        Return the batch size.
1958
1959        Returns:
1960            int, the batch size of data.
1961
1962        Examples:
1963            >>> import mindspore.dataset as ds
1964            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
1965            >>> dataset = dataset.batch(2)
1966            >>> batch_size = dataset.get_batch_size()
1967            >>> print(batch_size)
1968            2
1969        """
1970        if self._batch_size is None:
1971            runtime_getter = self._init_tree_getters()
1972            self._batch_size = runtime_getter[0].GetBatchSize()
1973        if self._batch_size is None:
1974            self._batch_size = 1
1975        return self._batch_size
1976
1977    def get_repeat_count(self):
1978        """
1979        Get the number of times the dataset is repeated. Default: ``1`` .
1980
1981        Returns:
1982            int, the count of repeat.
1983
1984        Examples:
1985            >>> import mindspore.dataset as ds
1986            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
1987            >>> dataset = dataset.repeat(5)
1988            >>> repeat_count = dataset.get_repeat_count()
1989            >>> print(repeat_count)
1990            5
1991        """
1992        if self._repeat_count is None:
1993            runtime_getter = self._init_tree_getters()
1994            self._repeat_count = runtime_getter[0].GetRepeatCount()
1995        if self._repeat_count is None:
1996            self._repeat_count = 1
1997        return self._repeat_count
1998
1999    def get_class_indexing(self):
2000        """
2001        Get the mapping dictionary from category names to category indexes.
2002
2003        This dictionary can be used to look up which category name corresponds to a particular category index.
2004
2005        Returns:
2006            Dict[str, int], the mappings from category names to category indexes.
2007
2008        Examples:
2009            >>> import mindspore.dataset as ds
2010            >>> # Read image files
2011            >>> image_folder_dataset_dir = "/path/to/image_folder_dataset_directory"
2012            >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir)
2013            >>> # Get the mapping from class names to class indexes of the image folder
2014            >>> class_indexing = dataset.get_class_indexing()
2015        """
2016        if self.children:
2017            return self.children[0].get_class_indexing()
2018        return {}
2019
2020    def reset(self):
2021        """
2022        Reset the dataset for the next epoch.
2023
2024        Examples:
2025            >>> import mindspore.dataset as ds
2026            >>> mind_dataset_dir = ["/path/to/mind_dataset_file"]
2027            >>> dataset = ds.MindDataset(dataset_files=mind_dataset_dir)
2028            >>> for _ in range(5):
2029            ...     num_iter = 0
2030            ...     for data in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
2031            ...         num_iter += 1
2032            ...     dataset.reset()
2033        """
2034
2035    def is_shuffled(self):
2036        """Return True if the dataset or any of its children is shuffled."""
2037        for input_dataset in self.children:
2038            if input_dataset.is_shuffled():
2039                return True
2040
2041        return False
2042
2043    def is_sharded(self):
2044        """Return True if the dataset or any of its children is sharded."""
2045        for input_dataset in self.children:
2046            if input_dataset.is_sharded():
2047                return True
2048
2049        return False
2050
2051    def parse(self, children=None):
2052        raise NotImplementedError("Dataset has to implement parse method.")
2053
2054    def __len__(self):
2055        """
2056        Get the length of the dataset.
2057
2058        Returns:
2059            int, the length of the dataset.
2060        """
2061        return self.get_dataset_size()
2062
2063    @staticmethod
2064    def _update_data_shard(num_shards, shard_id):
2065        """
2066        Update the shard number and shard id if necessary.
2067        This is normally used in distributed training mode like Parameter Server training.
2068        """
2069        # If this is in distributed execution mode,
2070        # the shard number and shard id might need to be updated according to the process's rank or role.
2071        worker_num = _get_ps_context("worker_num")
2072        server_num = _get_ps_context("server_num")
2073        if _is_role_pserver() and _enable_distributed_mindrt() and (worker_num != server_num):
2074            num_shards = worker_num
2075            shard_id = 0
2076        return num_shards, shard_id
2077
2078    def pre_parse(self, getter_mode):
2079        if getter_mode:
2080            if hasattr(self, "python_multiprocessing"):
2081                self.python_multiprocessing = False
2082            if hasattr(self, "num_parallel_workers"):
2083                self.num_parallel_workers = 1
2084
2085    def post_parse(self, ir_node):
2086        if self.cache:
2087            ir_node = ir_node.set_cache_client(self.cache.cache_client)
2088        if self.num_parallel_workers:
2089            ir_node = ir_node.set_num_workers(self.num_parallel_workers)
2090
2091        return ir_node
2092
2093    def set_init_step(self, init_step):
2094        self._global_step = init_step
2095
2096    def get_init_step(self):
2097        if self._global_step is not None:
2098            return self._global_step
2099        if len(self.children) == 1:
2100            return self.children[0].get_init_step()
2101        # When there are multiple children, we cannot tell from which child to get the initial step,
2102        # so we initialize from the beginning
2103        return 0
2104
2105
2106class VisionBaseDataset(Dataset):
2107    """
2108    Abstract class to represent a vision source dataset which produces content to the data pipeline.
2109    """
2110
2111    def __init__(self, children=None, num_parallel_workers=None, cache=None):
2112        super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache)
2113
2114    def parse(self, children=None):
2115        raise NotImplementedError("Dataset has to implement parse method.")
2116
2117
2118class TextBaseDataset(Dataset):
2119    """
2120    Abstract class to represent a text source dataset which produces content to the data pipeline.
2121    """
2122
2123    def __init__(self, children=None, num_parallel_workers=None, cache=None):
2124        super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache)
2125
2126    def parse(self, children=None):
2127        raise NotImplementedError("Dataset has to implement parse method.")
2128
2129    def build_vocab(self, columns, freq_range, top_k, special_tokens, special_first):
2130        """
2131        Function to create a Vocab from source dataset.
2132        Desired source dataset is a text type dataset.
2133
2134        Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab
2135        which contains top_k most frequent words (if top_k is specified).
2136
2137        Note:
2138            mindspore.dataset.Dataset.build_vocab is deprecated from version 2.0
2139            and will be removed in a future version. Use mindspore.dataset.text.Vocab.from_dataset instead.
2140
2141        Args:
2142            columns(Union[str, list[str]]): Column names to get words from.
2143            freq_range(tuple[int]): A tuple of integers (min_frequency, max_frequency). Words within the frequency
2144                range will be stored.
2145                Naturally 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
2146                can be set to default, which corresponds to 0/total_words respectively.
2147            top_k(int): Number of words to be built into the vocab. The top_k most frequent words are
2148                taken. The top_k is applied after freq_range. If there are fewer than top_k words, all words are taken.
2149            special_tokens(list[str]): A list of strings, each one is a special token.
2150            special_first(bool): Whether special_tokens will be prepended or appended to the vocab. If special_tokens
2151                is specified and special_first is set to default, special_tokens will be prepended.
2152
2153        Returns:
2154            Vocab, vocab built from the dataset.
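
        Examples:
            >>> # A sketch of the recommended replacement, mindspore.dataset.text.Vocab.from_dataset,
            >>> # assuming `dataset` is an existing text dataset with a "text" column (hypothetical):
            >>> import mindspore.dataset.text as text
            >>> vocab = text.Vocab.from_dataset(dataset, columns=["text"], special_tokens=["<pad>", "<unk>"])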
2155        """
2156        warnings.warn("mindspore.dataset.Dataset.build_vocab is deprecated from version 2.0 "
2157                      "and will be removed in a future version. "
2158                      "Use mindspore.dataset.text.Vocab.from_dataset instead.", DeprecationWarning)
2159
2160    def build_sentencepiece_vocab(self, columns, vocab_size, character_coverage, model_type, params):
2161        """
2162        Function to create a SentencePieceVocab from source dataset.
2163        Desired source dataset is a text type dataset.
2164
2165        Note:
2166            mindspore.dataset.Dataset.build_sentencepiece_vocab is deprecated from version 2.0
2167            and will be removed in a future version. Use mindspore.dataset.text.SentencePieceVocab.from_dataset instead.
2168
2169        Args:
2170            columns(list[str]): Column names to get words from.
2171            vocab_size(int): Vocabulary size.
2172            character_coverage(float): Percentage of characters covered by the model, must be between
2173                0.98 and 1.0. Good defaults are: 0.9995 for languages with rich character sets like
2174                Japanese or Chinese character sets, and 1.0 for other languages with small character sets
2175                like English or Latin.
2176            model_type(SentencePieceModel): Model type. Choose from unigram (default), bpe, char, or word.
2177                The input sentence must be pretokenized when using word type.
2178            params(dict): Any extra optional parameters for the sentencepiece library, depending on your raw data.
2179
2180        Returns:
2181            SentencePieceVocab, vocab built from the dataset.
2182        """
2183        warnings.warn("mindspore.dataset.Dataset.build_sentencepiece_vocab is deprecated from version 2.0 "
2184                      "and will be removed in a future version. "
2185                      "Use mindspore.dataset.text.SentencePieceVocab.from_dataset instead.", DeprecationWarning)
2186
2187    def _build_vocab(self, columns, freq_range, top_k, special_tokens, special_first):
2188        """
2189        Function to create a Vocab from source dataset.
2190        Desired source dataset is a text type dataset.
2191
2192        Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab
2193        which contains top_k most frequent words (if top_k is specified).
2194
2195        Args:
2196            columns(Union[str, list[str]]): Column names to get words from.
2197            freq_range(tuple[int]): A tuple of integers (min_frequency, max_frequency). Words within the frequency
2198                range will be stored.
2199                Naturally 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
2200                can be set to default, which corresponds to 0/total_words respectively.
2201            top_k(int): Number of words to be built into the vocab. The top_k most frequent words are
2202                taken. The top_k is applied after freq_range. If there are fewer than top_k words, all words are taken.
2203            special_tokens(list[str]): A list of strings, each one is a special token.
2204            special_first(bool): Whether special_tokens will be prepended or appended to the vocab. If special_tokens
2205                is specified and special_first is set to default, special_tokens will be prepended.
2206
2207        Returns:
2208            Vocab, vocab built from the dataset.
2209        """
2210        vocab = cde.Vocab()
2211        columns = replace_none(columns, [])
2212        if not isinstance(columns, list):
2213            columns = [columns]
2214
2215        freq_range = replace_none(freq_range, (0, 9223372036854775807))
2216        if freq_range[0] is None:
2217            freq_range = (0, freq_range[1])
2218        if freq_range[1] is None:
2219            freq_range = (freq_range[0], 9223372036854775807)
2220        special_tokens = replace_none(special_tokens, [])
2221        top_k = replace_none(top_k, 9223372036854775807)
2222
2223        ir_tree, api_tree = self.create_ir_tree()
2224
2225        # vocab node
2226        vocab_node = cde.BuildVocabNode(ir_tree, vocab, columns, freq_range, top_k, special_tokens, special_first)
2227
2228        runtime_context = cde.PythonRuntimeContext()
2229        runtime_context.Init()
2230
2231        # build vocab
2232        consumer = cde.PythonBuildVocabConsumer()
2233        consumer.Init(vocab_node)
2234        runtime_context.AssignConsumer(consumer)
2235
2236        consumer.Start()
2237        del api_tree
2238
2239        return vocab
2240
2241    def _build_sentencepiece_vocab(self, columns, vocab_size, character_coverage, model_type, params):
2242        """
2243        Function to create a SentencePieceVocab from source dataset.
2244        Desired source dataset is a text type dataset.
2245
2246        Args:
2247            columns(list[str]): Column names to get words from.
2248            vocab_size(int): Vocabulary size.
2249            character_coverage(float): Percentage of characters covered by the model, must be between
2250                0.98 and 1.0. Good defaults are: 0.9995 for languages with rich character sets like
2251                Japanese or Chinese character sets, and 1.0 for other languages with small character sets
2252                like English or Latin.
2253            model_type(SentencePieceModel): Model type. Choose from unigram (default), bpe, char, or word.
2254                The input sentence must be pretokenized when using word type.
2255            params(dict): Any extra optional parameters for the sentencepiece library, depending on your raw data.
2256
2257        Returns:
2258            SentencePieceVocab, vocab built from the dataset.
2259        """
2260        if not isinstance(model_type, SentencePieceModel):
2261            raise TypeError("Argument model_type with value {0} is not of type SentencePieceModel, but got {1}." \
2262                            .format(model_type, type(model_type)))
2263        model_type = DE_C_INTER_SENTENCEPIECE_MODE[model_type]
2264        vocab = cde.SentencePieceVocab()
2265
2266        ir_tree, api_tree = self.create_ir_tree()
2267
2268        # vocab node
2269        vocab_node = cde.BuildSentenceVocabNode(ir_tree, vocab, columns, vocab_size, character_coverage, model_type,
2270                                                params)
2271
2272        runtime_context = cde.PythonRuntimeContext()
2273        runtime_context.Init()
2274
2275        # build vocab
2276        consumer = cde.PythonBuildVocabConsumer()
2277        consumer.Init(vocab_node)
2278        runtime_context.AssignConsumer(consumer)
2279
2280        consumer.Start()
2281        del api_tree
2282
2283        return vocab
2284
2285
2286class AudioBaseDataset(Dataset):
2287    """
2288    Abstract class to represent an audio source dataset which produces content to the data pipeline.
2289    """
2290
2291    def __init__(self, children=None, num_parallel_workers=None, cache=None):
2292        super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache)
2293
2294    def parse(self, children=None):
2295        raise NotImplementedError("Dataset has to implement parse method.")
2296
2297
2298class UnionBaseDataset(VisionBaseDataset, TextBaseDataset, AudioBaseDataset):
2299    """
2300    Abstract class to represent a union source dataset which produces content to the data pipeline.
2301    """
2302
2303    def __init__(self, children=None, num_parallel_workers=None, cache=None):
2304        super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache)
2305
2306    def parse(self, children=None):
2307        raise NotImplementedError("Dataset has to implement parse method.")
2308
2309
2310class SourceDataset(Dataset):
2311    """
2312    Abstract class to represent a source dataset which produces content to the data pipeline.
2313    """
2314
2315    def __init__(self, num_parallel_workers=None, num_samples=None, shuffle=True, num_shards=None, shard_id=None,
2316                 cache=None):
2317        super().__init__(num_parallel_workers=num_parallel_workers, cache=cache)
2318        self.num_samples = replace_none(num_samples, 0)
2319        self.num_shards = replace_none(num_shards, 1)
2320        self.shard_id = replace_none(shard_id, 0)
2321
2322        if shuffle is not None and not isinstance(shuffle, (bool, Shuffle)):
2323            raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like 'Shuffle.GLOBAL' or "
2324                            "'Shuffle.FILES' or 'Shuffle.INFILE'.")
2325
2326        self.shuffle_flag = 2  # Global shuffle
2327        if not isinstance(shuffle, Shuffle):
2328            if shuffle is None or shuffle:
2329                self.shuffle_flag = 2  # Global shuffle
2330            else:
2331                self.shuffle_flag = 0  # No shuffle
2332        else:
2333            if shuffle == Shuffle.GLOBAL:
2334                self.shuffle_flag = 2  # Global shuffle
2335            elif shuffle == Shuffle.FILES:
2336                self.shuffle_flag = 1  # Files shuffle
2337            elif shuffle == Shuffle.INFILE:
2338                self.shuffle_flag = 3  # Infile shuffle
2339
2340    def parse(self, children=None):
2341        raise NotImplementedError("Dataset has to implement parse method.")
2342
2343    @staticmethod
2344    def _find_files(patterns):
2345        """
2346        Utility function to search for files with the given glob patterns.
2347
2348        Args:
2349            patterns (Union[str, list[str]]): String or list of patterns to be searched.
2350
2351        Returns:
2352            list, list of files.
2353        """
2354
2355        if not isinstance(patterns, list):
2356            patterns = [patterns]
2357
2358        file_list = []
2359        unmatched_patterns = []
2360        for pattern in patterns:
2361            matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)]
2362
2363            if matches:
2364                file_list.extend(matches)
2365            else:
2366                unmatched_patterns.append(pattern)
2367
2368        if unmatched_patterns:
2369            raise ValueError("The following patterns did not match any files: {}.".format(unmatched_patterns))
2370
2371        if file_list:  # not empty
2372            return file_list
2373        raise ValueError("The list of path names matching the patterns is empty.")
2374
2375    def is_shuffled(self):
2376        return self.shuffle_flag > 0
2377
2378    def is_sharded(self):
2379        if self.num_shards is not None:
2380            return self.num_shards > 1
2381        return False
2382
2383
2384class MappableDataset(SourceDataset):
2385    """
2386    Abstract class to represent a source dataset which supports use of samplers.
2387    """
2388
2389    def parse(self, children=None):
2390        raise NotImplementedError("Dataset has to implement parse method.")
2391
2392    def __init__(self, num_parallel_workers=None, sampler=None, num_samples=None, shuffle=None, num_shards=None,
2393                 shard_id=None, cache=None):
2394        num_shards, shard_id = self._update_data_shard(num_shards, shard_id)
2395        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
2396                         num_shards=num_shards, shard_id=shard_id, cache=cache)
2397        self.shuffle_flag = replace_none(shuffle, True)
2398        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
2399
2400    def add_sampler(self, new_sampler):
2401        """
2402        Add a child sampler for the current dataset.
2403
2404        Args:
2405            new_sampler (Sampler): The child sampler to be added.
2406
2407        Examples:
2408            >>> import mindspore.dataset as ds
2409            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
2410            >>>
2411            >>> new_sampler = ds.DistributedSampler(10, 2)
2412            >>> dataset.add_sampler(new_sampler)
2413        """
2414        # Note: By adding a sampler, the sampled IDs will flow to the new_sampler
2415        # after first passing through the current samplers attached to this dataset.
2416        self.dataset_size = None
2417        new_sampler.add_child(self.sampler)
2418        self.sampler = new_sampler
2419
2420    def use_sampler(self, new_sampler):
2421        """
2422        Replace the last child sampler of the current dataset, keeping the parent sampler unchanged.
2423
2424        Args:
2425            new_sampler (Sampler): The new sampler to replace with.
2426
2427        Examples:
2428            >>> import mindspore.dataset as ds
2429            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
2430            >>>
2431            >>> # use a DistributedSampler instead
2432            >>> new_sampler = ds.DistributedSampler(10, 2)
2433            >>> dataset.use_sampler(new_sampler)
2434        """
2435        if new_sampler is None:
2436            raise TypeError("Input sampler can not be None.")
2437        if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
2438            raise TypeError("Input sampler is not an instance of a sampler.")
2439        self.dataset_size = None
2440
2441        self.sampler = self.sampler.child_sampler
2442        self.add_sampler(new_sampler)
2443
2444    def is_shuffled(self):
2445        return self.sampler.is_shuffled()
2446
2447    def is_sharded(self):
2448        return self.sampler.is_sharded()
2449
2450    @check_split
2451    def split(self, sizes, randomize=True):
2452        """
2453        Split the dataset into smaller, non-overlapping datasets.
2454
2455        Args:
2456            sizes (Union[list[int], list[float]]): If a list of integers [s1, s2, …, sn] is
2457                provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
2458                respectively. If the sum of all sizes does not equal the original dataset size, an
2459                error will occur.
2460                If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1
2461                and must sum to 1, otherwise an error will occur. The dataset will be split into n
2462                Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the
2463                original dataset.
2464                If after rounding:
2465
2466                - Any size equals 0, an error will occur.
2467                - The sum of split sizes < K, the difference will be added to the first split.
2468                - The sum of split sizes > K, the difference will be removed from the first large
2469                  enough split such that it will have at least 1 row after removing the difference.
2470
2471            randomize (bool, optional): Determines whether or not to split the data randomly. Default: ``True``.
2472                If ``True``, the data will be randomly split. Otherwise, each split will be created with
2473                consecutive rows from the dataset.
2474
2475        Note:
2476            1. There is an optimized split function, which will be called automatically when the dataset
2477               that calls this function is a MappableDataset.
2478            2. Dataset should not be sharded if split is going to be called. Instead, create a
2479               :class:`mindspore.dataset.DistributedSampler` and specify a split to shard after splitting.
2480               If the dataset is sharded after a split, it is strongly recommended to set the same
2481               seed in each instance of execution, otherwise each shard may not be part of the same
2482               split (see Examples).
2483            3. It is strongly recommended to not shuffle the dataset, but set `randomize` to ``True`` instead.
2484               Shuffling the dataset may not be deterministic, which means the data in each split
2485               will be different in each epoch. Furthermore, if sharding occurs after split, each
2486               shard may not be part of the same split.
2487
2488        Returns:
2489            Tuple[Dataset], a tuple of new datasets split from the original one.
2490
2491        Raises:
2492            RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
2493            RuntimeError: If `sizes` is list of integers and sum of all elements in sizes does not
2494                equal the dataset size.
2495            RuntimeError: If `sizes` is list of float and there is a split with size 0 after calculations.
2496            RuntimeError: If the dataset is sharded prior to calling split.
2497            ValueError: If `sizes` is list of float and not all floats are between 0 and 1, or if the
2498                floats don't sum to 1.
2499
2500        Examples:
2501            >>> import mindspore.dataset as ds
2502            >>> # Since many datasets have shuffle on by default, set shuffle to False if split will be called!
2503            >>> image_folder_dataset_dir = "/path/to/image_folder_dataset_directory"
2504            >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, shuffle=False)
2505            >>>
2506            >>> # Set the seed, and tell split to use this seed when randomizing.
2507            >>> # This is needed because sharding will be done later
2508            >>> ds.config.set_seed(58)
2509            >>> train_dataset, test_dataset = dataset.split([0.9, 0.1])
2510            >>>
2511            >>> # To shard the train dataset, use a DistributedSampler
2512            >>> train_sampler = ds.DistributedSampler(10, 2)
2513            >>> train_dataset.use_sampler(train_sampler)
2514        """
2515        if self.is_shuffled():
2516            logger.warning("Dataset is shuffled before split.")
2517
2518        if self.is_sharded():
2519            raise RuntimeError("Dataset should not be sharded before split.")
2520
2521        absolute_sizes = self._get_absolute_split_sizes(sizes)
2522        splits = []
2523        current_split_start_index = 0
2524        for size in absolute_sizes:
2525            ds = copy.deepcopy(self)
2526            ds.dataset_size = None
2527            if randomize:
2528                # We want to shuffle the same way every epoch before the split; we assume
2529                # that the user will call set_seed.
2530                random_sampler = samplers.RandomSampler()
2531                random_sampler.reshuffle_each_epoch = False
2532                ds.add_sampler(random_sampler)
2533
2534            subset_sampler = samplers.SequentialSampler(current_split_start_index, size)
2535            ds.add_sampler(subset_sampler)
2536
2537            # Add a sequential sampler, so that if the user calls use_sampler, we will
2538            # get rid of this sequential sampler instead of a sampler we still need.
2539            ds.add_sampler(samplers.SequentialSampler())
2540
2541            splits.append(ds)
2542
2543            current_split_start_index += size
2544
2545        return tuple(splits)
2546
2547
2548class BucketBatchByLengthDataset(UnionBaseDataset):
2549    """
2550    The result of applying BucketBatchByLength operation to the input dataset.
2551    """
2552
2553    def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes, element_length_function,
2554                 pad_info, pad_to_bucket_boundary, drop_remainder):
2555        super().__init__(children=input_dataset)
2556
2557        self.column_names = to_list(column_names)
2558        self.bucket_boundaries = replace_none(bucket_boundaries, [])
2559        self.bucket_batch_sizes = replace_none(bucket_batch_sizes, [])
2560        self.element_length_function = element_length_function
2561        self.pad_info = replace_none(pad_info, {})
2562        self.pad_to_bucket_boundary = replace_none(pad_to_bucket_boundary, False)
2563        self.drop_remainder = replace_none(drop_remainder, False)
2564
2565    def parse(self, children=None):
2566        return cde.BucketBatchByLengthNode(children[0], self.column_names, self.bucket_boundaries,
2567                                           self.bucket_batch_sizes, self.element_length_function, self.pad_info,
2568                                           self.pad_to_bucket_boundary, self.drop_remainder)
2569
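# A minimal usage sketch of the public ``bucket_batch_by_length`` entry point that builds this node
# (illustrative only; the generator, its data and the bucket settings are made up for the example):
#
#   import numpy as np
#   import mindspore.dataset as ds
#
#   def gen():
#       for length in (1, 3, 2, 4):
#           yield (np.ones(length, dtype=np.int32),)
#
#   dataset = ds.GeneratorDataset(gen, ["text"], shuffle=False)
#   # rows with length < 3 go to the first bucket (batch size 2), the rest to the second (batch size 1);
#   # every row of the "text" column is padded to length 4 with zeros.
#   dataset = dataset.bucket_batch_by_length(column_names=["text"],
#                                            bucket_boundaries=[3],
#                                            bucket_batch_sizes=[2, 1],
#                                            element_length_function=lambda x: x.shape[0],
#                                            pad_info={"text": ([4], 0)})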
2570
2571def _check_shm_usage(num_worker, queue_size, in_rowsize, out_rowsize):
2572    """
2573    Check that sufficient shared memory is available for the shared memory queues
2574    when training in parallel mode.
2575    """
2576    threshold_ratio = 0.8
2577    # Verify available size only when using static shared memory on Linux
2578    if platform.system().lower() not in {"windows", "darwin"} and in_rowsize != -1 and out_rowsize != -1:
2579        device_num = _get_device_num()
2580        # In a cluster, _get_device_num returns the device count of the entire cluster. The maximum number of
2581        # cards on an Ascend server is 8.
2582        if device_num > 1:
2583            device_num = min(device_num, 8)
2584        shm_estimate_usage = device_num * num_worker * \
2585                             (queue_size + 2) * (in_rowsize + out_rowsize) * 1024 * 1024
2586        try:
2587            shm_available = psutil.disk_usage('/dev/shm').free
2588            if shm_estimate_usage >= threshold_ratio * shm_available:
2589                raise RuntimeError(
2590                    "Insufficient shared memory available. Required: {}, Available: {}. "
2591                    "The required memory can't exceed 80% of the available shared memory, "
2592                    "it's recommended to reduce memory usage by following methods:\n"
2593                    "1. reduce value of parameter max_rowsize or num_parallel_workers.\n"
2594                    "2. reduce prefetch size by set_prefetch_size().\n"
2595                    "3. disable shared memory by set_enable_shared_mem().".format(shm_estimate_usage, shm_available))
2596        except FileNotFoundError:
2597            raise RuntimeError("Expected /dev/shm to exist.")
2598
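# Worked example of the estimate above (illustrative numbers): with device_num=1, num_worker=8,
# queue_size=1 and in_rowsize=out_rowsize=16 MB, the estimated usage is
# 1 * 8 * (1 + 2) * (16 + 16) MB = 768 MB, which must stay below 80% of the free space in /dev/shm.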
2599
2600class BatchDataset(UnionBaseDataset):
2601    """
2602    The result of applying Batch operation to the input dataset.
2603
2604    Args:
2605        input_dataset (Dataset): Input Dataset to be batched.
2606        batch_size (Union[int, function]): The number of rows each batch is created with. An
2607            int or callable which takes exactly 1 parameter, BatchInfo.
2608        drop_remainder (bool, optional): Determines whether or not to drop the last
2609            possibly incomplete batch. Default: ``False``. If True, and if there are less
2610            than batch_size rows available to make the last batch, then those rows will
2611            be dropped and not propagated to the child node.
2612        num_parallel_workers (int, optional): Number of workers to process the dataset in parallel. Default: ``None``.
2613        per_batch_map (callable, optional): Per batch map callable. A callable which takes
2614            (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch of
2615            Tensors on a given column. The number of lists should match with number of entries in input_columns. The
2616            last parameter of the callable must always be a BatchInfo object.
2617        input_columns (Union[str, list[str]], optional): List of names of the input columns. The size of the list must
2618            match with signature of per_batch_map callable.
2619        output_columns (Union[str, list[str]], optional): List of names assigned to the columns outputted by
2620            the last operation. This parameter is mandatory if len(input_columns) !=
2621            len(output_columns). The size of this list must match the number of output
2622            columns of the last operation. Default: ``None``, output columns will have the same
2623            name as the input columns, i.e., the columns will be replaced.
2624        max_rowsize(Union[int, list[int]], optional): Maximum size of row in MB that is used for shared memory
2625            allocation to copy data between processes, the total occupied shared memory will increase as
2626            ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set to -1,
2627            shared memory will be dynamically allocated with the actual size of data. This is only used if
2628            ``python_multiprocessing`` is set to True. If it is an int value, it represents
2629            ``input_columns`` and ``output_columns`` use this value as the unit to create shared memory.
2630            If it is a list, the first element represents the ``input_columns`` use this value as the unit to
2631            create shared memory, and the second element represents ``output_columns`` use this value as the unit
2632            to create shared memory. Default: 16.
2633
2634    """
2635
2636    def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
2637                 input_columns=None, output_columns=None, python_multiprocessing=False, max_rowsize=16):
2638        super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers)
2639
2640        if BatchDataset._is_ancestor_of_repeat(input_dataset):
2641            logger.warning("Repeat is located before batch, data from two epochs can be batched together.")
2642
2643        BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
2644
2645        # if batch_size is callable, set batch_size to 1 and batch_size_func to that callable function
2646        self.batch_size = batch_size if not callable(batch_size) else 1
2647        self.batch_size_func = None if not callable(batch_size) else batch_size
2648
2649        self.drop_remainder = replace_none(drop_remainder, False)
2650
2651        self.per_batch_map = per_batch_map
2652
2653        self.input_columns = to_list(input_columns)
2654        self.output_columns = to_list(output_columns)
2655
2656        self.python_multiprocessing = python_multiprocessing
2657        self.process_pool = None
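        # Scale max_rowsize by batch_size because a batched output row is roughly batch_size input rows.
        # Illustrative arithmetic: max_rowsize=16 (MB) with batch_size=32 reserves [512, 512] MB for the
        # input and output shared-memory queues; max_rowsize=-1 keeps [-1, -1] for dynamic allocation.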
2658        if isinstance(max_rowsize, int):
2659            self.max_rowsize = [max_rowsize * self.batch_size] * 2 if max_rowsize != -1 else [max_rowsize, max_rowsize]
2660        else:
2661            self.max_rowsize = [max_rowsize[0] * self.batch_size, max_rowsize[1] * self.batch_size]
2662
2663    def __del__(self):
2664        if hasattr(self, "process_pool") and self.process_pool is not None:
2665            self.process_pool.terminate()
2666            del self.process_pool
2667
2668    def parse(self, children=None):
2669        return cde.BatchNode(children[0], self.batch_size, self.drop_remainder, False, self.input_columns,
2670                             self.output_columns, self.batch_size_func, self.per_batch_map, {},
2671                             self.process_pool)
2672
2673    @staticmethod
2674    def _is_ancestor_of_repeat(dataset):
2675        """
2676        Utility function to find the case where repeat is used before batch.
2677
2678        Args:
2679             dataset (Dataset): Dataset to be checked.
2680
2681        Returns:
2682            bool, whether repeat is used before batch.
2683        """
2684        if isinstance(dataset, RepeatDataset):
2685            return True
2686        flag = False
2687        for input_dataset in dataset.children:
2688            flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
2689        return flag
2690
2691    @staticmethod
2692    def _update_batch_size_for_syncwait(dataset, batch_size):
2693        """
2694        Utility function to notify batch size to sync_wait.
2695
2696        Args:
2697             dataset (Dataset): Dataset to be checked.
2698             batch_size (int): batch size to notify.
2699        """
2700        if isinstance(dataset, SyncWaitDataset):
2701            dataset.update_sync_batch_size(batch_size)
2702        for input_dataset in dataset.children:
2703            BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
2704
2705    def __deepcopy__(self, memodict):
2706        return self.__safe_deepcopy__(memodict, exclude=("per_batch_map", "batch_size_func", "__transfer_dataset__"))
2707
2708    # Iterator bootstrap will be called on iterator construction.
2709    # A deep copy of the Dataset object is created prior to iterator_bootstrap.
2710    # This method will create a per-iterator process pool and bind pyfunc execution to the pool.
2711    def iterator_bootstrap(self):
2712        """
2713        Per iterator bootstrap callback.
2714        """
2715        if self.python_multiprocessing and platform.system().lower() == 'windows':
2716            logger.warning("Python multiprocessing is not supported on Windows platform.")
2717        if self.python_multiprocessing and get_debug_mode():
2718            logger.warning("Python multiprocessing is not supported in debug mode."
2719                           " Ignoring Python multiprocessing for batch operation.")
2720            self.python_multiprocessing = False
2721        if self.python_multiprocessing and platform.system().lower() != 'windows':
2722            if self.per_batch_map is None:
2723                logger.warning("per_batch_map is None so python_multiprocessing is ignored for batch.")
2724                return
2725
2726            # If user didn't specify num_parallel_workers, set it to default
2727            if self.num_parallel_workers is None:
2728                self.num_parallel_workers = get_num_parallel_workers()
2729
2730            self.process_pool = _PythonMultiprocessing(str(self), self.num_parallel_workers, [self.per_batch_map],
2731                                                       self.max_rowsize)
2732            # Wrap per_batch_map into _PythonCallable
2733            self.per_batch_map = _PythonCallable(self.per_batch_map, 0, self.process_pool)
2734        else:
2735            if self.per_batch_map is not None:
2736                self.per_batch_map = FuncWrapper(self.per_batch_map)
2737
2738
2739class BatchInfo(cde.CBatchInfo):
2740    """
2741    This class helps to get dataset information dynamically when the input of `batch_size` or `per_batch_map`
2742    in `batch` operation is a callable object.
2743    """
2744
2745    def get_batch_num(self):
2746        """
2747        Return the batch number being processed in the current epoch, starting from 0.
2748
2749        Examples:
2750            >>> # Create a dataset where its batch size is dynamic
2751            >>> # Define a callable batch size function and let batch size increase 1 each time.
2752            >>> import mindspore.dataset as ds
2753            >>> from mindspore.dataset import BatchInfo
2754            >>>
2755            >>> dataset = ds.GeneratorDataset([i for i in range(3)], "column1", shuffle=False)
2756            >>> def add_one(BatchInfo):
2757            ...     return BatchInfo.get_batch_num() + 1
2758            >>> dataset = dataset.batch(batch_size=add_one)
2759            >>> print(list(dataset))
2760            [[Tensor(shape=[1], dtype=Int64, value= [0])], [Tensor(shape=[2], dtype=Int64, value= [1, 2])]]
2761        """
2762        return
2763
2764    def get_epoch_num(self):
2765        """
2766        Return the epoch number, starting from 0.
2767
2768        Examples:
2769            >>> # Create a dataset where its batch size is dynamic
2770            >>> # Define a callable batch size function and let batch size increase 1 each epoch.
2771            >>> import mindspore.dataset as ds
2772            >>> from mindspore.dataset import BatchInfo
2773            >>>
2774            >>> dataset = ds.GeneratorDataset([i for i in range(4)], "column1", shuffle=False)
2775            >>> def add_one_by_epoch(BatchInfo):
2776            ...     return BatchInfo.get_epoch_num() + 1
2777            >>> dataset = dataset.batch(batch_size=add_one_by_epoch)
2778            >>>
2779            >>> result = []
2780            >>> epoch = 2
2781            >>> iterator = dataset.create_tuple_iterator(num_epochs=epoch)
2782            >>> for i in range(epoch):
2783            ...    result.extend(list(iterator))
2784            >>> # result:
2785            >>> # [[Tensor(shape=[1], dtype=Int64, value= [0])], [Tensor(shape=[1], dtype=Int64, value= [1])],
2786            >>> #  [Tensor(shape=[1], dtype=Int64, value= [2])], [Tensor(shape=[1], dtype=Int64, value= [3])],
2787            >>> #  [Tensor(shape=[2], dtype=Int64, value= [0, 1])], [Tensor(shape=[2], dtype=Int64, value= [2, 3])]]
2788        """
2789        return
2790
2791
2792class BlockReleasePair:
2793    """
2794    The blocking condition class used by SyncWaitDataset.
2795
2796    Args:
2797        init_release_rows (int): Number of rows to allow through the pipeline.
2798        callback (function): The callback function that will be called when release is called. Default: ``None``.
2799    """
2800
2801    def __init__(self, init_release_rows, callback=None):
2802        if isinstance(init_release_rows, int) and init_release_rows <= 0:
2803            raise ValueError("release_rows need to be greater than 0.")
2804        self.row_count = -init_release_rows
2805        self.cv = threading.Condition()
2806        self.callback = callback
2807        self.default_rows = init_release_rows
2808        self.disable = False
2809
2810    def __deepcopy__(self, memodict):
2811        return self
2812
2813    def reset(self):
2814        with self.cv:
2815            self.row_count = -self.default_rows
2816            self.cv.notify_all()
2817
2818    def update_batched_size(self, batch_size):
2819        # sanity check
2820        if isinstance(batch_size, int) and batch_size <= 0:
2821            raise ValueError("batch_size need to be greater than 0.")
2822
2823        # should only be used before the pipeline is created
2824        self.row_count *= batch_size
2825        self.default_rows *= batch_size
2826
2827    def block_func(self):
2828        """
2829        Function for handling the blocking condition.
2830
2831        Returns:
2832            bool, True.
2833        """
2834        with self.cv:
2835            # if disable is true, the predicate always evaluates to true
2836            not_time_out = self.cv.wait_for(lambda: (self.row_count < 0 or self.disable),
2837                                            timeout=get_callback_timeout())
2838            # not_time_out will be False if a timeout occurs
2839            if not not_time_out:
2840                logger.warning("Timeout happened in sync_wait, maybe dataset.sync_update(condition=...) "
2841                               "is not added after dataset.create_dict_iterator(...), now disabling lock.")
2842                self.disable = True
2843            self.row_count += 1
2844        return True
2845
2846    def release_func(self, pass_rows=None, data=None):
2847        with self.cv:
2848            if pass_rows is None:
2849                pass_rows = self.default_rows
2850            self.row_count -= pass_rows
2851            if self.callback is not None:
2852                self.callback(data)
2853            self.cv.notify_all()
2854
2855    def disable_lock(self):
2856        with self.cv:
2857            self.disable = True
2858            self.cv.notify_all()
2859
2860
2861class PaddedBatchDataset(UnionBaseDataset):
2862    """
2863    The result of applying the padded_batch operation to the input dataset.
2864
2865    Args:
2866        input_dataset (Dataset): Input Dataset to be batched.
2867        batch_size (Union[int, function]): The number of rows each batch is created with. An
2868            int or callable which takes exactly 1 parameter, BatchInfo.
2869        drop_remainder (bool, optional): Determines whether or not to drop the last
2870            possibly incomplete batch. Default: ``False``. If True, and if there are less
2871            than batch_size rows available to make the last batch, then those rows will
2872            be dropped and not propagated to the child node.
2873        num_parallel_workers (int, optional): Number of workers to process the dataset in parallel. Default: ``None``.
2874        pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)}
2875            will pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0.
2876    """
2877
2878    def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None, pad_info=None):
2879        super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers)
2880
2881        if PaddedBatchDataset._is_ancestor_of_repeat(input_dataset):
2882            logger.warning("Repeat is located before padded_batch, data from two epochs can be batched together.")
2883
2884        PaddedBatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
2885
2886        # if batch_size is callable, set batch_size to 1 and batch_size_func to that callable function
2887        self.batch_size = batch_size if not callable(batch_size) else 1
2888        self.batch_size_func = None if not callable(batch_size) else batch_size
2889
2890        self.drop_remainder = replace_none(drop_remainder, False)
2891
2892        self.pad = bool(pad_info is not None)
2893        self.pad_info = replace_none(pad_info, dict())
2894
2895    def parse(self, children=None):
2896        return cde.BatchNode(children[0], self.batch_size, self.drop_remainder, self.pad, [],
2897                             [], self.batch_size_func, None, self.pad_info, None)
2898
2899    @staticmethod
2900    def _is_ancestor_of_repeat(dataset):
2901        """
2902        Utility function to find the case where repeat is used before batch.
2903
2904        Args:
2905             dataset (Dataset): Dataset to be checked.
2906
2907        Returns:
2908            bool, whether repeat is used before batch.
2909        """
2910        if isinstance(dataset, RepeatDataset):
2911            return True
2912        flag = False
2913        for input_dataset in dataset.children:
2914            flag = flag | PaddedBatchDataset._is_ancestor_of_repeat(input_dataset)
2915        return flag
2916
2917    @staticmethod
2918    def _update_batch_size_for_syncwait(dataset, batch_size):
2919        """
2920        Utility function to notify batch size to sync_wait.
2921
2922        Args:
2923             dataset (Dataset): Dataset to be checked.
2924             batch_size (int): batch size to notify.
2925        """
2926        if isinstance(dataset, SyncWaitDataset):
2927            dataset.update_sync_batch_size(batch_size)
2928        for input_dataset in dataset.children:
2929            PaddedBatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
2930
2931    def __deepcopy__(self, memodict):
2932        return self.__safe_deepcopy__(memodict, exclude=("batch_size_func", "__transfer_dataset__"))
2933
2934
2935class SyncWaitDataset(UnionBaseDataset):
2936    """
2937    The result of adding a blocking condition to the input Dataset.
2938
2939    Args:
2940        input_dataset (Dataset): Input dataset to apply flow control.
2941        num_batch (int): Number of batches without blocking at the start of each epoch.
2942        condition_name (str): Condition name that is used to toggle sending next row.
2943        callback (function): Callback function that will be invoked when sync_update is called. Default: ``None``.
2944
2945    Raises:
2946        RuntimeError: If condition name already exists.
2947    """
2948
2949    def __init__(self, input_dataset, condition_name, num_batch, callback=None):
2950        super().__init__(children=input_dataset)
2951
2952        # set to the default value, waiting for the batch to update it
2953        self._condition_name = condition_name
2954        if isinstance(num_batch, int) and num_batch <= 0:
2955            raise ValueError("num_batch need to be greater than 0.")
2956
2957        self._pair = BlockReleasePair(num_batch, callback)
2958        if self._condition_name in self.children[0].get_sync_notifiers():
2959            raise RuntimeError("Condition name is already in use.")
2960        logger.info("Please remember to add dataset.sync_update(condition=%s), otherwise hanging will result. "
2961                    "If dataset.sync_update(condition=%s) has already been added, you can ignore the info.",
2962                    condition_name, condition_name)
2963
2964    def parse(self, children=None):
2965        return cde.SyncWaitNode(children[0], self._condition_name, self._pair.block_func)
2966
2967    def get_sync_notifiers(self):
2968        return {**self.children[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
2969
2970    def is_sync(self):
2971        return True
2972
2973    def update_sync_batch_size(self, batch_size):
2974        if isinstance(batch_size, int) and batch_size <= 0:
2975            raise ValueError("num_batch need to be greater than 0.")
2976        self._pair.update_batched_size(batch_size)
2977
2978    def disable_sync(self):
2979        logger.info("Disabling Sync")
2980        self._pair.disable_lock()
2981
2982    @staticmethod
2983    def _is_ancestor_of_batch(dataset):
2984        """
2985        Utility function to find the case where sync_wait is used before batch.
2986
2987        Args:
2988             dataset (Dataset): Dataset to be checked.
2989
2990        Returns:
2991            bool, whether sync_wait is used before batch.
2992        """
2993        if isinstance(dataset, (BatchDataset, PaddedBatchDataset)):
2994            return True
2995        flag = False
2996        for input_dataset in dataset.children:
2997            flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
2998        return flag
2999
3000    def iterator_bootstrap(self):
3001        self._pair.reset()
3002
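# A minimal usage sketch of the sync mechanism above, assuming the public ``sync_wait``/``sync_update``
# helpers with these keyword names (illustrative only; the condition name and data are made up):
#
#   import numpy as np
#   import mindspore.dataset as ds
#
#   dataset = ds.GeneratorDataset([(np.array(i),) for i in range(6)], ["data"], shuffle=False)
#   dataset = dataset.sync_wait(condition_name="policy", num_batch=1)
#   dataset = dataset.batch(2)
#   for batch in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
#       ...  # consume the batch, e.g. update some state
#       dataset.sync_update(condition_name="policy")  # release the next num_batch batches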
3003
3004class ShuffleDataset(UnionBaseDataset):
3005    """
3006    The result of applying Shuffle operation to the input Dataset.
3007
3008    Args:
3009        input_dataset (Dataset): Input Dataset to be shuffled.
3010        buffer_size (int): Size of the buffer.
3011
3012    Raises:
3013        RuntimeError: If a sync operation exists before shuffle.
3014    """
3015
3016    def __init__(self, input_dataset, buffer_size):
3017        super().__init__(children=input_dataset)
3018        self.buffer_size = buffer_size
3019        self.reshuffle_each_epoch = True
3020
3021        if self.is_sync():
3022            raise RuntimeError("No shuffle after sync operators.")
3023
3024    def parse(self, children=None):
3025        return cde.ShuffleNode(children[0], self.buffer_size, self.reshuffle_each_epoch)
3026
3027    def is_shuffled(self):
3028        return True
3029
3030
3031# Pyfunc collection for multiprocess pyfunc.
3032# These global variables will only be used within subprocesses.
3033_OP_NAME = dict()
3034_OP_PROCESS = dict()
3035
3036
3037# PythonCallable wrapper for multiprocess pyfunc
3038class _PythonCallable:
3039    """
3040    Internal Python function wrapper for multiprocessing pyfunc.
3041    """
3042
3043    def __init__(self, py_callable, idx, pool=None):
3044        # Original Python callable from user.
3045        self.py_callable = py_callable
3046        # Process pool created for current iterator.
3047        self.pool = pool
3048        # Python callable index
3049        self.idx = idx
3050
3051    def __call__(self, *args):
3052        result = None
3053        get_data_from_worker_process = False
3054        while get_data_from_worker_process is False:
3055            if self.pool.is_running() and check_iterator_cleanup() is False:
3056                try:
3057                    result = self.pool.execute(self.idx, *args)
3058                except multiprocessing.TimeoutError:
3059                    continue
3060                get_data_from_worker_process = True
3061            else:
3062                # worker process is stopped
3063                logger.info("The worker process of map operation is stopped. "
3064                            "So return None to main thread and break the main thread.")
3065                return None
3066        # got value from worker process
3067        if not isinstance(result, tuple) and get_data_from_worker_process is True:
3068            result = (result,)
3069        return result
3070
3071    def to_json(self):
3072        return self.py_callable.to_json()
3073
3074
3075# used when python_multiprocessing=True in map
3076class Pipe:
3077    """
3078    Class to handle communication between the master process and the worker processes.
3079    """
3080
3081    def __init__(self, warning_ctl, shared_memory=False, max_rowsize=16):
3082        self.shared_memory = shared_memory
3083        self.eof = multiprocessing.Event()
3084        if self.shared_memory:
3085            self.in_queue = _SharedQueue(1, warning_ctl, max_rowsize=max_rowsize[0])
3086            self.res_queue = _SharedQueue(1, warning_ctl, max_rowsize=max_rowsize[1])
3087        else:
3088            self.in_queue = _Queue(1)
3089            self.res_queue = _Queue(1)
3090        self.in_queue.cancel_join_thread()  # Ensure that the process does not hang when exiting
3091
3092    def master_send(self, func_index, data):
3093        self.in_queue.put_nowait((func_index, *data))
3094
3095    def master_receive(self):
3096        if self.eof is None:
3097            raise RuntimeError("EOF is none when get data from worker.")
3098        if self.eof.is_set():
3099            return None
3100        return self.res_queue.get(timeout=1)
3101
3102    def master_close(self):
3103        self.eof.set()
3104        self.send_finish_signal_to_worker()
3105        self.send_finish_signal()
3106
3107    def send_finish_signal(self):
3108        self.worker_send(None)
3109
3110    def send_finish_signal_to_worker(self):
3111        self.master_send(0, "QUIT")
3112
3113    def worker_send(self, data):
3114        self.res_queue.put_until(data, timeout=1, exit_signal=self.eof)
3115
3116    def worker_receive(self):
3117        result = self.in_queue.get_until(timeout=1, exit_signal=self.eof)
3118        if result is None:
3119            return result
3120        if len(result) == 1:
3121            raise RuntimeError(f"Corrupted data. Worker received {len(result)} elements, it should be more than 1.")
3122        func_index, *data = result
3123        return func_index, tuple(data)
3124
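# The framing used by Pipe is simply ``(func_index, *data)`` on the request queue and the raw result
# (or an ExceptionHandler) on the response queue. A rough equivalent with plain multiprocessing queues
# (illustrative only; the internal _Queue/_SharedQueue add size limits, shared memory and exit signals):
#
#   import multiprocessing as mp
#
#   in_queue, res_queue = mp.Queue(1), mp.Queue(1)
#   in_queue.put((0, "row"))               # master_send(func_index=0, data=("row",))
#   func_index, *data = in_queue.get()     # worker_receive() -> (0, ("row",))
#   res_queue.put(tuple(data))             # worker_send(result)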
3125
3126def _main_process_already_exit():
3127    """
3128    Check whether the main process has already exited.
3129    """
3130    ppid = os.getppid()
3131
3132    if (platform.system().lower() != 'windows' and
3133            not _PythonMultiprocessing.is_process_alive(ppid)):
3134        return True
3135    return False
3136
3137
3138def _worker_loop(operations, pipe, worker_id):
3139    """
3140    Multiprocess worker process loop.
3141    """
3142    # Ensure that the process does not hang when exiting
3143    pipe.res_queue.cancel_join_thread()
3144
3145    def _ignore_sigint():
3146        """
3147        We need to ignore the SIGINT signal here so that subprocesses can exit normally and clean up.
3148        """
3149        signal.signal(signal.SIGINT, signal.SIG_IGN)
3150
3151    # If the default random seed has not been changed, there is no need to fix the randomness.
3152    # Otherwise, set the random seed for each child process to "base_seed + worker_id" to ensure
3153    # that the random results of each process are different.
3154    if get_seed() != 5489:
3155        set_seed(get_seed() + worker_id)
3156    while not _main_process_already_exit():
3157        _ignore_sigint()
3158
3159        result = pipe.worker_receive()
3160        if result is None:
3161            return
3162        (idx, input_tensors) = result
3163        if input_tensors == "QUIT":
3164            break
3165        try:
3166            output_tensors = operations[idx](*input_tensors)
3167
3168            pipe.worker_send(output_tensors)
3169        except Exception:
3170            pipe.worker_send(ExceptionHandler(where="in map(or batch) worker and execute Python function"))
3171            # Do not return
3172
3173    # release the queues when the worker is stopped by the master
3174    del pipe.in_queue
3175    del pipe.res_queue
3176
3177
3178def worker_target(operations, worker_id):
3179    return lambda pipe: _worker_loop(operations, pipe, worker_id)
3180
3181
3182class _MPWorker(multiprocessing.Process):
3183    """
3184    Worker process for multiprocessing.
3185    """
3186
3187    def __init__(self, operations, warning_ctl, max_rowsize=16, worker_id=0):
3188        shared_memory = get_enable_shared_mem()
3189        self.pipe = Pipe(warning_ctl, shared_memory=shared_memory, max_rowsize=max_rowsize)
3190        self.check_interval = get_multiprocessing_timeout_interval()
3191        super().__init__(target=worker_target(operations, worker_id), name="MapWorker" + str(worker_id),
3192                         args=(self.pipe,), daemon=True)
3193
3194    def execute(self, idx, *args):
3195        """Acquiring data from a worker in an infinite loop"""
3196        self.pipe.master_send(idx, args)
3197        time_s = time.time()
3198        wait_count = 1
3199        while True:
3200            cost_time = time.time() - time_s
3201            if cost_time / self.check_interval >= wait_count:
3202                wait_count += 1
3203                logger.warning("It has been waiting for " + "%.3f" % cost_time + "s because the sub-process "
3204                               "worker of the map operation is hanging. "
3205                               "Check whether the user defined data transform is too slow or the "
3206                               "output data is too large. You can also set the timeout interval by "
3207                               "ds.config.set_multiprocessing_timeout_interval to adjust the output frequency "
3208                               "of this log.")
3209                pid = self.pid
3210                logger.warning("Map worker subprocess ID {} is stuck.".format(pid))
3211                install_status, _ = subprocess.getstatusoutput("py-spy --version")
3212                if install_status == 0:
3213                    stack = subprocess.getoutput("py-spy dump -p {} -l".format(pid))
3214                    logger.warning("Map worker subprocess stack:\n{}".format(stack))
3215                else:
3216                    logger.warning("Please `pip install py-spy` to get the stacks of the stuck process.")
3217            try:
3218                res = self.pipe.master_receive()
3219                # Because no copy is needed when creating Tensors in the C++ layer, the time from
3220                # np.ndarray to C++ Tensor creation is reduced. However, when shared memory is used across
3221                # multiple processes, the address of the shared memory is always passed to subsequent nodes
3222                # in the dataset pipeline, and the shared memory is also written by the current node, so
3223                # subsequent nodes in the pipeline may read dirty data. Therefore, make a memory copy here
3224                # to avoid the shared memory being contaminated.
3225                if get_enable_shared_mem():
3226                    res = copy.deepcopy(res)
3227            except queue.Empty:
3228                continue
3229            if res is None:
3230                # receive finish signal
3231                return None
3232            if isinstance(res, ExceptionHandler):
3233                res.reraise()
3234            return res
3235
3236    def close(self):
3237        try:
3238            if self.is_alive():
3239                # release the eager executor which is used by current process
3240                transforms.transforms.clean_unused_executors()
3241
3242                logger.info(f"Closing worker with PID: {self.pid}")
3243                self.pipe.master_close()
3244                # delete the handles held by the master process
3245                del self.pipe.in_queue
3246                del self.pipe.res_queue
3247                super().terminate()
3248                super().join()
3249                super().close()
3250
3251        except ValueError:
3252            # Process has been closed already
3253            return
3254        return
3255
3256    def is_alive(self):
3257        try:
3258            return super().is_alive()
3259        except ValueError:
3260            return False
3261
3262
3263class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
3264    """
3265    A wrapper around the multiprocessing worker pool that performs cleanup and ensures proper termination of forked processes.
3266    """
3267
3268    class _ExceptHookHandler:
3269        """
3270        Internal class ExceptionHandler
3271        """
3272
3273        def __init__(self):
3274            self.origin_hook = sys.excepthook
3275            sys.excepthook = self.__handler_exception
3276
3277        @staticmethod
3278        def mp_pool_exit_preprocess():
3279            if check_iterator_cleanup() is False:
3280                # Set the iterator_cleanup flag to True before exiting, and wait 3s for all apply_async
3281                # tasks applied to the multiprocessing pool, to prevent multiprocessing from hanging when exiting
3282                _set_iterator_cleanup()
3283                time.sleep(3)
3284
3285        def __handler_exception(self, ex_type, value, tb):
3286            self.origin_hook(ex_type, value, tb)
3287            self.mp_pool_exit_preprocess()
3288
3289    def __init__(self, op_name, num_parallel_workers, operations, max_rowsize=16):
3290        super(_PythonMultiprocessing, self).__init__()
3291        self.op_name = op_name
3292        self.num_parallel_workers = num_parallel_workers
3293        self.operations = operations
3294        self.max_rowsize = max_rowsize
3295
3296        self.workers = None
3297        self.pids = None
3298        self.op_id = -1
3299
3300        self.queues_map = {}
3301        self.next_queue = 0
3302
3303        self.eot = None
3304        self.watch_dog = None
3305        self.ppid = os.getpid()
3306        self.hook = None
3307        self.warning_ctl = None
3308        # cache thread (get_ident()) to worker_id mapping in Python layer
3309        self.python_threads_to_workers = {}
3310        self.eof = None
3311
3312    def __del__(self):
3313        try:
3314            self.terminate()
3315        except TypeError:
3316            pass
3317
3318    # This wait function is for cleaning zombie subprocesses
3319    @staticmethod
3320    def wait_pid():
3321        """
3322        This function is used by the main process to release subprocess resources.
3323        """
3324        try:
3325            while True:
3326                child_pid, _ = os.waitpid(-1, os.WNOHANG)
3327                if child_pid == 0:
3328                    break
3329        except OSError:
3330            # waitpid may fail for some reasons, so we ignore this error
3331            pass
3332
3333    # Dataset needs a watch_dog thread to monitor forked multiprocessing workers,
3334    # and the thread can't be a member function, otherwise Python won't collect and release resources.
3335    @staticmethod
3336    def _watch_dog(eot, workers):
3337        """
3338        This thread is for monitoring subprocesses forked by GeneratorDataset/map/batch
3339        """
3340        if not isinstance(workers, list):
3341            raise TypeError("[Internal Error] The 2nd parameter of watch dog thread should be list of process, "
3342                            "but got {}.".format(type(workers)))
3343
3344        while not eot.is_set():
3345            # Monitor and count how many subprocesses have already exited
3346            clear_subprocess_timeout = _PythonMultiprocessing._monitor_subprocess_exit(workers)
3347            # If a subprocess exit is found, we will wait for the returned timeout and do some waitpid operations
3348            if clear_subprocess_timeout > 0:
3349                start = time.time()
3350                while time.time() - start < clear_subprocess_timeout:
3351                    # We need to distinguish between get_dataset_size or training finishing normally and a
3352                    # hang scenario. If get_dataset_size or training finished normally, _stop_subprocess can be
3353                    # executed and self.need_abort can be set to True. If the main process hangs in get(),
3354                    # self.need_abort will never be set to True, so we wait for 30s and then kill the main process
3355                    if eot.is_set():
3356                        return
3357                    # Sometimes a subprocess may be a zombie, so during this wait we can do some useful work (waitpid).
3358                    _PythonMultiprocessing.wait_pid()
3359                # multiprocessing.Queue may hang in .get() forever when the put() process was killed.
3360                # We have to exit the main process, otherwise it will hang.
3361                _PythonMultiprocessing._terminate_processes(workers)
3362                logger.critical("The subprocess of dataset may exit unexpected or be killed, "
3363                                "main process will exit. If this is not an artificial operation, you can use "
3364                                "ds.config.set_enable_watchdog(False) to block this error.")
3365                os.kill(os.getpid(), signal.SIGTERM)
3366
3367        # release the workers
3368        del workers
3369
3370    @staticmethod
3371    def _terminate_processes(processes):
3372        """Terminate subprocesses"""
3373
3374        for p in processes:
3375            try:
3376                if p.exitcode is None:
3377                    p.terminate()
3378            except Exception:  # pylint: disable=broad-except
3379                # process has been closed already
3380                pass
3381        for p in processes:
3382            if p._closed is False:  # pylint: disable=W0212
3383                # We don't use w.join because join can only be used in the main process, otherwise it will raise an error.
3384                p._popen.wait()  # pylint: disable=W0212
3385
3386    # Monitor the number of exited subprocesses
3387    @staticmethod
3388    def _monitor_subprocess_exit(workers):
3389        """
3390        Monitor whether any worker process has exited.
3391
3392        Args:
3393            workers (list of multiprocessing.Process): multiprocessing.Process.
3394
3395        Returns:
3396            int, the timeout (in seconds) to wait when a process has exited; 0 if no process has exited.
3397        """
3398        for w in workers:
3399            try:
3400                exit_code = w.exitcode
3401                if exit_code is not None:
3402                    # For kill -9, we can exit quickly
3403                    if exit_code == -9:
3404                        return 1
3405                    # For kill -15, we still exit after 30s
3406                    if exit_code == -15:
3407                        return 30
3408                # In some cases the subprocess has been killed but the exitcode is still None.
3409                # So we use os.kill(pid, 0) to check if it is alive.
3410                subprocess_alive = _PythonMultiprocessing.is_process_alive(w.pid)
3411                if not subprocess_alive:
3412                    # Like kill -15, we wait 30s before exit
3413                    return 30
3414            except ValueError:
3415                # process has been closed already
3416                return 0
3417        return 0
3418
3419    @staticmethod
3420    def is_process_alive(pid):
3421        """
3422        Check if the process is alive or not.
3423        Note: We hit a deadlock when we use psutil or w.exitcode to check whether a process is alive.
3424        Instead we use os.kill(pid, 0).
3425
3426        Args:
3427            pid: pid of the process to be checked
3428
3429        Returns:
3430            True if the process is alive
3431        """
3432
3433        try:
3434            os.kill(pid, 0)
3435        except OSError:
3436            return False
3437        return True
3438
3439    # When the main process exits, subprocesses will be terminated
3440    @staticmethod
3441    def _clean_process(ppid, workers, quit_signal):
3442        """
3443            This is the execution function of the cleaning process: if the main process has exited, clean up the subprocesses.
3444
3445        Args:
3446            ppid: The process id of the main process.
3447            workers: The list of subprocesses.
3448            quit_signal: The event used to signal quitting.
3449        """
3450        signal.signal(signal.SIGINT, signal.SIG_IGN)
3451        while _PythonMultiprocessing.is_process_alive(ppid):
3452            if quit_signal.is_set():
3453                return
3454            time.sleep(0.1)
3455
3456        _PythonMultiprocessing._terminate_processes(workers)
3457        del workers
3458        os.kill(os.getpid(), signal.SIGTERM)
3459
3460    def launch(self, op_id=-1):
3461        """
3462        Launch Python multiprocessing pool.
3463
3464        Args:
3465            op_id: ID of the operation for which the Python multiprocessing pool is launched.
3466
3467        Returns:
3468            None. The Python multiprocessing pool is launched.
3469        """
3470        self.python_threads_to_workers = {}
3471        self.op_id = op_id
3472        logger.info("Launching new Python Multiprocessing pool for Op:" + str(self.op_id))
3473        if self.is_mp_enabled():
3474            message = "Launching a new Python multiprocessing pool while a pool already exists!" + \
3475                " The existing pool will be terminated first."
3476            logger.warning(message)
3477            self.terminate()
3478            self.reset()
3479        self.create_pool()
3480
3481    def create_pool(self):
3482        """
3483
3484        Returns:
3485
3486        """
3487        if get_enable_shared_mem():
3488            _check_shm_usage(self.num_parallel_workers, 1, self.max_rowsize[0], self.max_rowsize[1])
3489
3490        if self.workers is not None:
3491            raise Exception("Pool was already created, close it first.")
3492
3493        # Let gc collect unreferenced memory to avoid child processes in the pool to do it
3494        gc.collect()
3495
3496        # Construct python worker processes
3497        self.workers = []
3498        self.warning_ctl = multiprocessing.Value('i', 0)
3499        for worker_id in range(self.num_parallel_workers):
3500            worker = _MPWorker(self.operations, self.warning_ctl, self.max_rowsize, worker_id)
3501            worker.start()
3502            self.workers.append(worker)
3503
3504        logger.info("Op: " + str(self.op_id) + " Python multiprocessing pool workers' PIDs: " + str(self.get_pids()))
3505
3506        self.hook = _PythonMultiprocessing._ExceptHookHandler()
3507
3508        # The op (Map, Batch, etc) multiprocessing will launch a watch dog thread for monitoring sub processes
3509        self._launch_watch_dog()
3510
3511        atexit.register(self.terminate)
3512
3513    def terminate(self):
3514        # close watch dog first and then close all the workers
3515        self.abort_watchdog()
3516        self.close_all_workers()
3517        if hasattr(self, "warning_ctl"):
3518            del self.warning_ctl
3519
3520    def get_pids(self):
3521        """
3522        Get the list of workers' PIDs.
3523
3524        Returns:
3525            list of int, the PIDs of the worker processes.
3526        """
3527        if not self.is_mp_enabled():
3528            return []
3529        if not self.pids:
3530            self.pids = []
3531            if self.workers:
3532                for w in self.workers:
3533                    try:
3534                        self.pids.append(w.pid)
3535                    except ValueError:
3536                        continue
3537        return self.pids
3538
3539    def add_new_workers(self, num_new_workers):
3540        logger.info(
3541            "Increasing num_parallel_workers of Python Multiprocessing pool for Op:" + str(self.op_id) +
3542            ", old num_workers=" + str(self.num_parallel_workers) + " new num_workers=" + str(
3543                self.num_parallel_workers +
3544                num_new_workers) + ".")
3545        self.terminate()
3546        self.num_parallel_workers += num_new_workers
3547        self.launch(self.op_id)
3548
3549    def remove_workers(self, num_removed_workers):
3550        logger.info(
3551            "Decreasing num_parallel_workers of Python Multiprocessing pool for Op:" + str(self.op_id) +
3552            ", old num_workers=" + str(self.num_parallel_workers) + " new num_workers=" + str(
3553                self.num_parallel_workers -
3554                num_removed_workers) + ".")
3555        self.terminate()
3556        self.num_parallel_workers -= num_removed_workers
3557        self.launch(self.op_id)
3558
3559    def is_mp_enabled(self):
3560        return self.workers is not None
3561
3562    def execute(self, idx, *args):
3563        """
3564        Execute
3565        """
3566        t_id = threading.get_ident()
3567        # get the worker_id from the Python layer cache first; get it from the C++ layer if not found.
3568        worker_id = self.python_threads_to_workers.setdefault(t_id, self.get_thread_to_worker())
3569        if worker_id >= len(self.workers):
3570            raise RuntimeError("[Internal] worker_id value is greater than number of available workers!")
3571
3572        # todo check_iterator_cleanup
3573        if self.is_running() and check_iterator_cleanup() is False:
3574            return self.workers[worker_id].execute(idx, *args)
3575
3576        return None
3577
3578    def _launch_watch_dog(self):
3579        """
3580        Launch a watchdog thread and a cleaning process to clean up subprocesses when a process is killed.
3581        The watchdog thread cleans up the subprocesses and the main process when one of the subprocesses is killed.
3582        The cleaning process cleans up the subprocesses when the main process is killed.
3583        """
3584        if platform.system().lower() != 'windows':
3585            self.eof = multiprocessing.Event()
3586            self.cleaning_process = multiprocessing.Process(target=self._clean_process,
3587                                                            name="MapCleanProcess",
3588                                                            args=(self.ppid, self.workers, self.eof),
3589                                                            daemon=True)
3590            self.cleaning_process.start()
3591
3592            if get_enable_watchdog():
3593                self.eot = threading.Event()
3594                self.watch_dog = threading.Thread(target=self._watch_dog,
3595                                                  name="MapWatchDog",
3596                                                  args=(self.eot, self.workers + [self.cleaning_process]),
3597                                                  daemon=True)
3598                self.watch_dog.start()
3599
3600    def _abort_watchdog(self):
3601        if not self.eot.is_set():
3602            self.eot.set()
3603
3604    def abort_watchdog(self):
3605        if hasattr(self, 'watch_dog') and self.watch_dog is not None and hasattr(self, 'eot') and self.eot is not None:
3606            self._abort_watchdog()
3607        if hasattr(self, 'cleaning_process') and self.cleaning_process is not None:
3608            if hasattr(self, 'eof') and self.eof is not None and not self.eof.is_set():
3609                self.eof.set()
3610            _PythonMultiprocessing._terminate_processes([self.cleaning_process])
3611            del self.cleaning_process
3612
3613    def is_running(self):
3614        if hasattr(self, 'workers') and self.workers is not None:
3615            return all([w.is_alive() for w in self.workers])
3616        return False
3617
3618    def close_all_workers(self):
3619        """Close all the subprocess workers"""
3620        if hasattr(self, 'workers') and self.workers is not None:
3621            for w in self.workers:
3622                w.close()
3623            check_interval = get_multiprocessing_timeout_interval()
3624            for w in self.workers:
3625                try:
3626                    subprocess_file_descriptor = w.sentinel
3627                    st = time.time()
3628                    while _PythonMultiprocessing.is_process_alive(w.pid):
3629                        time.sleep(0.01)  # sleep 10ms, waiting for the subprocess exit
3630                        if time.time() - st > check_interval:
3631                            logger.warning("Waiting for the subprocess worker [{}] to exit.".format(w.pid))
3632                            st += check_interval
3633                except ValueError as e:
3634                    if "process object is closed" in str(e):
3635                        continue
3636                    raise e
3637                try:
3638                    if w.is_alive():
3639                        os.close(subprocess_file_descriptor)
3640                except OSError as e:
3641                    # The file descriptor may have already been released, so ignore the 'Bad file descriptor' error
3642                    if "Bad file descriptor" not in str(e):
3643                        raise e
3644
3645            # clear() explicitly releases the worker handles, which is preferable to only setting self.workers = None
3646            self.workers.clear()
3647            self.workers = None
3648            self.pids = None
3649
3650
3651class MapDataset(UnionBaseDataset):
3652    """
3653    The result of applying the Map operation to the input Dataset.
3654
3655    Args:
3656        input_dataset (Dataset): Input Dataset to be mapped.
3657        operations (Union[list[TensorOperation], list[callable]]): Operations that map a nested structure of tensors
3658            to another nested structure of tensors. Default: ``None``.
3659        input_columns (Union[str, list[str]]): List of names of the input columns.
3660            Default: ``None``, the operations will be applied to the first column(s) in the dataset.
3661            The size of the list should match the number of inputs of the first operation.
3662        output_columns (Union[str, list[str]], optional): List of names of the output columns.
3663            The size of the list should match the number of outputs of the last operation.
3664            Default: ``None``, output columns will be the input columns, i.e., the columns will
3665            be replaced.
3666        num_parallel_workers (int, optional): Number of workers to process the dataset
3667            in parallel. Default: ``None``.
3668        python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
3669            option can be beneficial if the Python operations are computationally heavy. Default: ``False``.
3670        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
3671            Default: ``None``, which means no cache is used.
3672        callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called. Default: ``None``.
3673        max_rowsize (Union[int, list[int]], optional): Maximum size of a row in MB that is used for shared memory
3674            allocation to copy data between processes, the total occupied shared memory will increase as
3675            ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set to -1,
3676            shared memory will be dynamically allocated with the actual size of data. This is only used if
3677            ``python_multiprocessing`` is set to True. If it is an int, both ``input_columns`` and
3678            ``output_columns`` use this value as the unit for allocating shared memory. If it is a list, the first
3679            element is the unit for ``input_columns`` and the second element is the unit for ``output_columns``.
3680            Default: 16.
3681        offload (bool, optional): Flag to indicate whether offload is used. Default: ``None``.
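
    A minimal usage sketch (hedged: this node is normally created through Dataset.map() rather than constructed
    directly; the GeneratorDataset source, the lambda and the parameter values below are illustrative assumptions):

    Examples:
        >>> import mindspore.dataset as ds
        >>> # build a small source dataset and apply a Python callable via map
        >>> data = ds.GeneratorDataset([1, 2, 3, 4], column_names=["col1"], shuffle=False)
        >>> data = data.map(operations=[lambda x: x + 1], input_columns=["col1"],
        ...                 python_multiprocessing=False, max_rowsize=16)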
3682    """
3683
3684    def __init__(self, input_dataset, operations=None, input_columns=None, output_columns=None,
3685                 num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16,
3686                 offload=None):
3687        super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers, cache=cache)
3688        self.operations = to_list(operations)
3689        for op in self.operations:
3690            # passing a class such as c_vision.HWC2CHW without parentheses (i.e. not instantiated) is an error
3691            if type(op) == type:  # pylint: disable=unidiomatic-typecheck
3692                raise ValueError("Each element of the 'operations' parameter of map should be a dataset processing "
3693                                 "operation instance, but got: {}. It may be missing parentheses for "
3694                                 "instantiation.".format(op))
3695            if not isinstance(op, (c_transforms.TensorOperation, py_transforms.PyTensorOperation)) \
3696                    and not callable(op):
3697                raise ValueError("Each element of the 'operations' parameter of map should be a callable Python "
3698                                 "function or class method, but got: {}. A Python function or class method does "
3699                                 "not need parentheses here.".format(op))
3700
3701        self.input_columns = to_list(input_columns)
3702        self.output_columns = to_list(output_columns)
3703
3704        #  If output_columns were not provided then use input_columns
3705        self.output_columns = self.input_columns if not self.output_columns else self.output_columns
3706
3707        self.python_multiprocessing = python_multiprocessing
3708        self.process_pool = None
3709
3710        self.callbacks = to_list(callbacks)
3711        if isinstance(max_rowsize, int):
3712            self.max_rowsize = [max_rowsize] * 2
3713        else:
3714            self.max_rowsize = max_rowsize
3715        self.offload = offload
3716
3717    def parse(self, children=None):
3718        operations = self.__decompose_callable_operations()
3719
3720        count_old_transforms, count_new_transforms, count_non_data_vision_transforms = \
3721            self.__count_transforms(operations)
3722        count_pyfunc = self.__count_pyfuncs(operations)
3723        if count_new_transforms + count_pyfunc == len(operations):
3724            prev_op = None
3725            for op in operations:
3726                # skip user-added DebugHook ops to avoid switching the implementation to Python.
3727                if self.__is_debug_hook_op(op):
3728                    if prev_op:
3729                        # manually set previous_op_name
3730                        prev_op_name = self.__parse_op_name(prev_op)
3731                        op.set_previous_op_name(prev_op_name)
3732                    continue
3733                if op.implementation is None:
3734                    if prev_op and prev_op.implementation == Implementation.PY:
3735                        op.implementation = Implementation.PY
3736                    else:
3737                        op.implementation = Implementation.C
3738                prev_op = op
3739            operations = self.__insert_debug_wrapper(operations)
3740            operations = transforms.transforms.Compose.reduce(operations)
3741        elif count_old_transforms + count_pyfunc + count_non_data_vision_transforms == len(operations):
3742            operations = self.__insert_debug_wrapper(operations)
3743            operations = transforms.py_transforms.Compose.reduce(operations)
3744        else:
3745            raise RuntimeError("Mixing old legacy c/py_transforms and new unified transforms is not allowed.")
3746
3747        self.operations = self.__process_final_operations(operations)
3748        self.prepare_multiprocessing()
3749
3750        callbacks = [cb.create_runtime_obj() for cb in self.callbacks]
3751        return cde.MapNode(children[0], self.operations, self.input_columns, self.output_columns,
3752                           callbacks, OffloadToManualOffloadMode.get(self.offload), self.process_pool)
3753
3754    def __deepcopy__(self, memodict):
3755        return self.__safe_deepcopy__(memodict, exclude=("operations", "callbacks", "__transfer_dataset__"))
3756
3757    def __del__(self):
3758        if hasattr(self, "process_pool") and self.process_pool is not None:
3759            self.process_pool.terminate()
3760            del self.process_pool
3761
3762    @staticmethod
3763    def __parse_op_name(op):
3764        """
3765        Utility method to get operation name.
3766        """
3767        op_name = ""
3768        if isinstance(op, transforms.py_transforms_util.FuncWrapper):
3769            try:
3770                op_name = op.transform.__name__
3771            except (AttributeError,):
3772                op_name = op.transform.__class__.__name__
3773        else:
3774            op_name = op.__class__.__name__
3775        return op_name
3776
3777    @staticmethod
3778    def __construct_debug_hook(previous_op_name=None, is_first_op=False):
3779        """
3780        Wrap debug hook into FuncWrapper.
3781        """
3782        inserted_functions = []
3783        debug_hook_list = _get_debug_hook_list()
3784        if debug_hook_list:
3785            for fn in debug_hook_list:
3786                # making deep copy to allow each debug hook instance hold unique variables
3787                new_fn = copy.deepcopy(fn)
3788                new_fn.set_previous_op_name(previous_op_name)
3789                new_fn.set_is_first(is_first_op)
3790                inserted_func = transforms.py_transforms_util.FuncWrapper(new_fn)
3791                inserted_func.implementation = Implementation.PY
3792                inserted_functions.append(inserted_func)
3793        return inserted_functions
3794
3795    @staticmethod
3796    def __is_debug_hook_op(op):
3797        """
3798        Check if the op is a user-added DebugHook; it is skipped to avoid changing the transform implementation.
3799        """
3800        if isinstance(op, DebugHook):
3801            if not get_debug_mode():
3802                raise ValueError("It is not allowed to inject DebugHook object in non-debug mode.")
3803            return True
3804        return False
3805
3806    @staticmethod
3807    def __count_pyfuncs(operations):
3808        """
3809        Count the number of pyfunc operations
3810        """
3811        return sum([1 if isinstance(op, FuncWrapper) else 0 for op in operations])
3812
3813    @staticmethod
3814    def __count_transforms(operations):
3815        """
3816        Count the various flavors of transforms operations
3817        """
3818        # Count the number of old legacy data and vision c_transforms and py_transforms
3819        count_old_transforms = sum(
3820            [1 if "c_transforms" in str(op)
3821             or isinstance(op, (c_transforms.TensorOperation, py_transforms.PyTensorOperation))
3822             or ("py_transforms" in str(op) and not isinstance(op, FuncWrapper))
3823             else 0 for op in operations])
3824        # Count the number of new unified data and vision transforms
3825        count_new_transforms = sum([1 if hasattr(op, "implementation") and not isinstance(op, FuncWrapper)
3826                                    else 0 for op in operations])
3827        # Count the number of non-data transforms and non-vision transforms
3828        count_non_data_vision_transforms = sum(
3829            [1 if "text.transforms" in str(op) or "audio.transforms" in str(op) else 0 for op in operations])
3830        return count_old_transforms, count_new_transforms, count_non_data_vision_transforms
3831
3832    @staticmethod
3833    def __operation_valid_for_multiprocessing(op):
3834        if callable(op) and str(op).find("c_transform") < 0:
3835            return True
3836        return False
3837
3838    @staticmethod
3839    def __process_final_operations(operations):
3840        """
3841        Build final list of operations
3842        """
3843        operations_fin = []
3844        for op in operations:
3845            if hasattr(op, "implementation"):
3846                if op.implementation == Implementation.C and not isinstance(op, (FuncWrapper, ToNumpy)):
3847                    operations_fin.append(op.parse())
3848                elif op.implementation == Implementation.PY:
3849                    operations_fin.append(op)
3850                elif isinstance(op, (FuncWrapper, ToNumpy)):
3851                    operations_fin.append(op)
3852                else:
3853                    raise RuntimeError("Wrong implementation")
3854            else:
3855                if op and getattr(op, 'parse', None):
3856                    operations_fin.append(op.parse())
3857                else:
3858                    operations_fin.append(op)
3859        return operations_fin
3860
3861    # Iterator bootstrap will be called on iterator construction.
3862    # A deep copy of the Dataset object is created prior to iterator_bootstrap.
3863    # This method creates a per-iterator process pool and binds pyfunc execution to the pool.
3864    def prepare_multiprocessing(self):
3865        """
3866        Per iterator bootstrap callback.
3867        """
3868        if self.python_multiprocessing and platform.system().lower() == 'windows':
3869            logger.warning("Python multiprocessing is not supported on Windows platform.")
3870            return
3871        if self.python_multiprocessing and get_debug_mode():
3872            logger.warning("Python multiprocessing is not supported in debug mode."
3873                           " Ignoring Python multiprocessing for map operation.")
3874            return
3875        if self.python_multiprocessing:
3876            iter_specific_operations = []
3877            callable_list = []
3878
3879            # If user didn't specify num_parallel_workers, set it to default
3880            if self.num_parallel_workers is None:
3881                self.num_parallel_workers = get_num_parallel_workers()
3882
3883            # Pass #1, look for Python callables and build list
3884            for op in self.operations:
3885                # C++ transforms are callable but should not be run in the Python multiprocessing pool
3886                if MapDataset.__operation_valid_for_multiprocessing(op):
3887                    callable_list.append(op)
3888
3889            if callable_list:
3890                self.process_pool = _PythonMultiprocessing(str(self), self.num_parallel_workers, callable_list,
3891                                                           self.max_rowsize)
3892                # Pass #2
3893                idx = 0
3894                for op in self.operations:
3895                    # C++ transforms are callable but should not be run in the Python multiprocessing pool
3896                    if MapDataset.__operation_valid_for_multiprocessing(op):
3897                        # Wrap Python callable into _PythonCallable
3898                        iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
3899                        idx += 1
3900                    else:
3901                        # CPP ops remain the same
3902                        iter_specific_operations.append(op)
3903                self.operations = iter_specific_operations
3904
3905    def __insert_debug_wrapper(self, operations):
3906        """
3907        Insert debug hook wrappers before and after each op if debug mode is on.
3908        """
3909        if not get_debug_mode():
3910            return operations
3911        first_op_name = self.__parse_op_name(operations[0])
3912        inserted_operations = self.__construct_debug_hook(first_op_name, is_first_op=True)
3913        for op in operations:
3914            inserted_operations.append(op)
3915            op_name = self.__parse_op_name(op)
3916            inserted_operations.extend(self.__construct_debug_hook(op_name))
3917        return inserted_operations
3918
3919    def __decompose_callable_operations(self):
3920        """
3921        Decompose operations and build list of old legacy ops which are callable
3922        """
3923        decomposed_operations = transforms.transforms.Compose.decompose(self.operations)
3924        operations = []
3925        for op in decomposed_operations:
3926            if callable(op) and not hasattr(op, "implementation") and str(op).find(
3927                    "c_transform") < 0 and not isinstance(op, c_transforms.TensorOperation) and \
3928                    not isinstance(op, py_transforms.PyTensorOperation):
3929                op = transforms.py_transforms_util.FuncWrapper(op)
3930            operations.append(op)
3931        return operations
3932
3933
3934class FilterDataset(UnionBaseDataset):
3935    """
3936    The result of applying filter predicate to the input Dataset.
3937
3938    Args:
3939        input_dataset (Dataset): Input Dataset to be filtered.
3940        predicate (callable): Python callable which returns a boolean value. If it returns False, the element is filtered out.
3941        input_columns (Union[str, list[str]], optional): List of names of the input columns.
3942            Default: ``None``, the predicate will be applied to all columns in the dataset.
3943        num_parallel_workers (int, optional): Number of workers to process the dataset
3944            in parallel. Default: ``None``.
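
    A minimal usage sketch (hedged: this node is normally produced by Dataset.filter(); the source dataset and
    the predicate below are illustrative assumptions):

    Examples:
        >>> import mindspore.dataset as ds
        >>> data = ds.GeneratorDataset([1, 2, 3, 4], column_names=["col1"], shuffle=False)
        >>> # keep only the rows whose value is greater than 2
        >>> data = data.filter(predicate=lambda col1: col1 > 2, input_columns=["col1"])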
3945    """
3946
3947    def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
3948        super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers)
3949        self.predicate = lambda *args: bool(predicate(*args))
3950        self.input_columns = to_list(input_columns)
3951
3952    def parse(self, children=None):
3953        return cde.FilterNode(children[0], self.predicate, self.input_columns)
3954
3955
3956class RepeatDataset(UnionBaseDataset):
3957    """
3958    The result of applying Repeat operation to the input Dataset.
3959
3960    Args:
3961        input_dataset (Dataset): Input Dataset to be repeated.
3962        count (int): Number of times the dataset will be repeated. Default: -1, repeat indefinitely.
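
    A minimal usage sketch (hedged: normally produced by Dataset.repeat(); the source dataset below is an
    illustrative assumption):

    Examples:
        >>> import mindspore.dataset as ds
        >>> data = ds.GeneratorDataset([1, 2, 3], column_names=["col1"], shuffle=False)
        >>> data = data.repeat(2)  # the pipeline now yields 6 rows per epoch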
3963    """
3964
3965    def __init__(self, input_dataset, count):
3966        super().__init__(children=input_dataset)
3967        self.count = replace_none(count, -1)
3968
3969    def parse(self, children=None):
3970        return cde.RepeatNode(children[0], self.count)
3971
3972
3973class SkipDataset(UnionBaseDataset):
3974    """
3975    The result of applying Skip operation to the input Dataset.
3976
3977    Args:
3978        input_dataset (Dataset): Input dataset to have elements skipped.
3979        count (int): Number of elements to be skipped in the dataset.
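
    A minimal usage sketch (hedged: normally produced by Dataset.skip(); the source dataset below is an
    illustrative assumption):

    Examples:
        >>> import mindspore.dataset as ds
        >>> data = ds.GeneratorDataset([1, 2, 3, 4], column_names=["col1"], shuffle=False)
        >>> data = data.skip(1)  # drops the first row, 3 rows remain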
3980    """
3981
3982    def __init__(self, input_dataset, count):
3983        super().__init__(input_dataset)
3984        self.count = count
3985
3986    def parse(self, children=None):
3987        return cde.SkipNode(children[0], self.count)
3988
3989
3990class TakeDataset(UnionBaseDataset):
3991    """
3992    The result of applying Take operation to the input Dataset.
3993
3994    Args:
3995        input_dataset (Dataset): Input Dataset to have elements taken from.
3996        count (int): Number of elements to be taken from the dataset.
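
    A minimal usage sketch (hedged: normally produced by Dataset.take(); the source dataset below is an
    illustrative assumption):

    Examples:
        >>> import mindspore.dataset as ds
        >>> data = ds.GeneratorDataset([1, 2, 3, 4], column_names=["col1"], shuffle=False)
        >>> data = data.take(2)  # keeps only the first 2 rows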
3997    """
3998
3999    def __init__(self, input_dataset, count):
4000        super().__init__(children=input_dataset)
4001        self.count = count
4002
4003    def parse(self, children=None):
4004        return cde.TakeNode(children[0], self.count)
4005
4006
4007class ZipDataset(UnionBaseDataset):
4008    """
4009    The result of applying Zip operation to the input Dataset.
4010
4011    Args:
4012        datasets (tuple): A tuple of datasets to be zipped together.
4013
4014    Raises:
4015        TypeError: If dataset is not an instance of Dataset.
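
    A minimal usage sketch (hedged: normally produced by Dataset.zip() or mindspore.dataset.zip(); the source
    datasets below are illustrative assumptions, and their column names must not overlap):

    Examples:
        >>> import mindspore.dataset as ds
        >>> data1 = ds.GeneratorDataset([1, 2, 3], column_names=["col1"], shuffle=False)
        >>> data2 = ds.GeneratorDataset([4, 5, 6], column_names=["col2"], shuffle=False)
        >>> data = ds.zip((data1, data2))  # each row now carries both "col1" and "col2"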
4016    """
4017
4018    def __init__(self, datasets):
4019        super().__init__(children=datasets)
4020
4021    def parse(self, children=None):
4022        return cde.ZipNode(children)
4023
4024    def is_sync(self):
4025        return any([c.is_sync() for c in self.children])
4026
4027
4028class ConcatDataset(UnionBaseDataset):
4029    """
4030    The result of applying Concat operation to the input Dataset.
4031
4032    Args:
4033        datasets (list): A list of datasets to be concatenated together.
4034
4035    Raises:
4036        TypeError: If dataset is not an instance of Dataset.
4037        ValueError: If there are no samples in one of the datasets.
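
    A minimal usage sketch (hedged: normally produced by Dataset.concat() or the ``+`` operator; the source
    datasets below are illustrative assumptions and share the same column layout):

    Examples:
        >>> import mindspore.dataset as ds
        >>> data1 = ds.GeneratorDataset([1, 2, 3], column_names=["col1"], shuffle=False)
        >>> data2 = ds.GeneratorDataset([4, 5, 6], column_names=["col1"], shuffle=False)
        >>> data = data1.concat(data2)  # equivalent to data1 + data2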
4038    """
4039
4040    def __init__(self, datasets):
4041        super().__init__(children=datasets)
4042        for dataset in datasets:
4043            if not isinstance(dataset, Dataset):
4044                raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset))
4045        self.datasets = datasets
4046        self._sampler = samplers.SequentialSampler(num_samples=None)
4047
4048        self.children_sizes_ = [c.get_dataset_size() for c in self.children]
4049        child_index = 0
4050        for item in self.children_sizes_:
4051            if item == 0:
4052                raise ValueError("There are no samples in dataset number %d. Please make sure there are "
4053                                 "valid samples in the dataset." % child_index)
4054            child_index += 1
4055
4056        self._children_sizes = self.children_sizes_.copy()
4057
4058        # _children_flag_and_nums: A list of (int, int) pairs. The first element of each pair flags whether the
4059        # dataset is mappable; the second element is the length of the dataset.
4060        self._children_flag_and_nums = []
4061
4062        # _children_start_end_index_: A list of (int, int) pairs. The elements of each pair describe the valid
4063        # sample range of the dataset at the corresponding index when sampling.
4064        self._children_start_end_index_ = []
4065        for index, child in enumerate(self.children):
4066            tem_list = [-1, -1]
4067            self._children_start_end_index_.append(tem_list)
4068            dataset_len = self.children_sizes_[index]
4069
4070            from mindspore.dataset.engine.datasets_user_defined import GeneratorDataset
4071            if isinstance(child, GeneratorDataset) and not hasattr(child.source, "__getitem__"):
4072                dataset_len = 0
4073                self.children_sizes_[index] = 0
4074
4075            if isinstance(child, MappableDataset):
4076                self._children_flag_and_nums.append((0, dataset_len))
4077            else:
4078                self._children_flag_and_nums.append((1, dataset_len))
4079
4080    def parse(self, children=None):
4081        return cde.ConcatNode(children, self._sampler, self._children_flag_and_nums, self._children_start_end_index_,
4082                              self._children_sizes)
4083
4084    def use_sampler(self, sampler):
4085        """
4086        Set a DistributedSampler or RandomSampler for the concat dataset.
4087
4088        Args:
4089            sampler (Sampler): The sampler to use for the current dataset.
4090                Currently supported: DistributedSampler, RandomSampler.
4091
4092        Raises:
4093            TypeError: If the sampler is not an instance of DistributedSampler or RandomSampler.
4094            ValueError: If the parameter shuffle of the sampler is True.
4095            ValueError: If the parameter num_samples of the sampler is not None.
4096            ValueError: If num_shards <= 0.
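
        A minimal usage sketch (hedged: the datasets and the shard configuration below are illustrative
        assumptions; the sampler must use shuffle=False and leave num_samples unset, as required above):

        Examples:
            >>> import mindspore.dataset as ds
            >>> data1 = ds.GeneratorDataset([1, 2, 3], column_names=["col1"], shuffle=False)
            >>> data2 = ds.GeneratorDataset([4, 5, 6], column_names=["col1"], shuffle=False)
            >>> data = data1.concat(data2)
            >>> sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False)
            >>> data.use_sampler(sampler)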
4097        """
4098        if not isinstance(sampler, (samplers.DistributedSampler, samplers.RandomSampler)):
4099            raise TypeError("The parameter %s of concat must be DistributedSampler or RandomSampler!" % sampler)
4100
4101        if isinstance(sampler, samplers.RandomSampler):
4102            if sampler.replacement:
4103                raise ValueError("The parameter replacement of RandomSampler must be False!")
4104
4105            if sampler.get_num_samples() is not None:
4106                raise ValueError("The parameter num_samples of RandomSampler cannot be set!")
4107
4108            self._sampler = sampler
4109            self._children_sizes = [c.get_dataset_size() for c in self.children]
4110
4111            # Recursive access to other child concat nodes
4112            def set_child(node):
4113                for c in node.children:
4114                    if isinstance(c, ConcatDataset):
4115                        c.use_sampler(sampler)
4116                    set_child(c)
4117            set_child(self)
4118
4119            return
4120
4121        if sampler.is_shuffled():
4122            raise ValueError("The parameter shuffle of DistributedSampler must be False!")
4123
4124        if sampler.num_shards <= 0:
4125            raise ValueError("The parameter num_shards of DistributedSampler must be positive int!")
4126
4127        if sampler.get_num_samples() is not None:
4128            raise ValueError("The parameter num_samples of DistributedSampler cannot be set!")
4129
4130        self.dataset_size = None
4131
4132        self._sampler = sampler
4133        cumulative_samples_nums = 0
4134        for index, child in enumerate(self.children):
4135            if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None:
4136                raise ValueError("The parameter num_samples of %s cannot be set!" % child)
4137
4138            if isinstance(child, (BatchDataset, PaddedBatchDataset)):
4139                raise TypeError("The parameter %s of concat must not be BatchDataset or PaddedBatchDataset!" % child)
4140
4141            # if child is mappable and the length is greater than 0
4142            if not self._children_flag_and_nums[index][0] and self._children_flag_and_nums[index][1]:
4143
4144                tem_value = cumulative_samples_nums + self._children_flag_and_nums[index][1]
4145
4146                if not self._children_flag_and_nums[index][1] >= sampler.num_shards:
4147                    if tem_value < sampler.num_shards:
4148                        self._children_start_end_index_[index][0] = cumulative_samples_nums
4149                        self._children_start_end_index_[index][1] = tem_value
4150                    else:
4151                        self._children_start_end_index_[index][0] = cumulative_samples_nums
4152                        self._children_start_end_index_[index][1] = tem_value % sampler.num_shards
4153
4154                tem_sampler = copy.deepcopy(sampler)
4155                tem_sampler.set_offset(cumulative_samples_nums)
4156                child.use_sampler(tem_sampler)
4157
4158            cumulative_samples_nums += self.children_sizes_[index]
4159            cumulative_samples_nums %= sampler.num_shards
4160
4161
4162class RenameDataset(UnionBaseDataset):
4163    """
4164    The result of applying Rename operation to the input Dataset.
4165
4166    Args:
4167        input_dataset (Dataset): Input Dataset to be Renamed.
4168        input_columns (Union[str, list[str]]): List of names of the input columns.
4169        output_columns (Union[str, list[str]]): List of names of the output columns.
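
    A minimal usage sketch (hedged: normally produced by Dataset.rename(); the source dataset and the new column
    name below are illustrative assumptions):

    Examples:
        >>> import mindspore.dataset as ds
        >>> data = ds.GeneratorDataset([1, 2, 3], column_names=["col1"], shuffle=False)
        >>> data = data.rename(input_columns=["col1"], output_columns=["feature"])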
4170    """
4171
4172    def __init__(self, input_dataset, input_columns, output_columns):
4173        super().__init__(children=input_dataset)
4174        self.input_column_names = to_list(input_columns)
4175        self.output_column_names = to_list(output_columns)
4176
4177    def parse(self, children=None):
4178        return cde.RenameNode(children[0], self.input_column_names, self.output_column_names)
4179
4180
4181def to_list(items):
4182    if items is None:
4183        return []
4184    if isinstance(items, tuple):
4185        return list(items)
4186    if not isinstance(items, list):
4187        return [items]
4188    return items
4189
4190
4191class ProjectDataset(UnionBaseDataset):
4192    """
4193    The result of applying Project operation to the input Dataset.
4194
4195    Args:
4196        input_dataset (Dataset): Input Dataset to be Projected.
4197        columns (Union[str, list[str]]): List of names of the columns to project.
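
    A minimal usage sketch (hedged: normally produced by Dataset.project(); the source dataset below is an
    illustrative assumption):

    Examples:
        >>> import mindspore.dataset as ds
        >>> data = ds.GeneratorDataset([(1, 4), (2, 5), (3, 6)], column_names=["col1", "col2"], shuffle=False)
        >>> data = data.project(columns=["col1"])  # keep only "col1"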
4198    """
4199
4200    def __init__(self, input_dataset, columns):
4201        super().__init__(children=input_dataset)
4202        self.columns = to_list(columns)
4203
4204    def parse(self, children=None):
4205        return cde.ProjectNode(children[0], self.columns)
4206
4207
4208class _ToDevice:
4209    """
4210    Internal class to handle sending data to device.
4211    """
4212
4213    def __init__(self, dataset, num_epochs):
4214        if get_debug_mode():
4215            logger.error("MindData debugger cannot be used in dataset sink mode. Please manually turn off "
4216                         "sink mode and try debugger again.")
4217        ir_tree, self.api_tree = dataset.create_ir_tree()
4218
4219        self._runtime_context = cde.PythonRuntimeContext()
4220        self._runtime_context.Init()
4221        self._to_device = cde.ToDevice(num_epochs)
4222        if dataset.get_init_step() != 0:
4223            init_step = dataset.get_init_step()
4224            dataset_size = dataset.get_dataset_size()
4225            self._to_device.Init(ir_tree, init_step, dataset_size)
4226        else:
4227            self._to_device.Init(ir_tree, 0, -1)
4228        self._runtime_context.AssignConsumer(self._to_device)
4229
4230        ITERATORS_LIST.append(weakref.ref(self))
4231        _unset_iterator_cleanup()
4232
4233    def send(self):
4234        self._to_device.Send()
4235
4236    def stop_send(self):
4237        """
4238        Send a stop-send signal to the pipeline; used when the end-of-sequence flag is sent at the end of an epoch.
4239        """
4240        self._to_device.StopSend()
4241
4242    def continue_send(self):
4243        """
4244        Send a continue-send signal to the pipeline; used when the end-of-sequence flag is sent at the end of an epoch.
4245        """
4246        self._to_device.ContinueSend()
4247
4248    def get_data_info(self):
4249        """
4250        Get type and shape of current batch.
4251        """
4252        return self._to_device.GetDataInfo()
4253
4254    def get_mbuf_queue_size(self):
4255        """
4256        Get the number of elements inside the mbuf queue.
4257        """
4258        return self._to_device.GetMbufQueueSize()
4259
4260    def get_send_info(self):
4261        """
4262        In sink mode, return the send information of the dataset at this moment.
4263        Send information includes the number of sent batches, a time summary of fetching data on the host,
4264        and a time summary of sending data.
4265        """
4266        return self._to_device.GetSendInfo()
4267
4268    def release(self):
4269        """
4270        Manually terminate Device Queue instead of relying on out of scope destruction.
4271        """
4272        if hasattr(self, '_runtime_context') and self._runtime_context:
4273            if hasattr(self, '_to_device') and self._to_device:
4274                self._runtime_context.Terminate()
4275                del self._to_device
4276            del self._runtime_context
4277
4278    def __deepcopy__(self, memodict):
4279        return self
4280
4281    def get_offload_model(self, col_names):
4282        """
4283        Get offload model containing removed offload ops from pipeline.
4284        """
4285        offload_model = GetOffloadModel(self._to_device, col_names)
4286        return offload_model
4287
4288    def _reset(self, step, dataset_size):
4289        self._to_device.Reset(step, dataset_size)
4290
4291
4292class TransferDataset(Dataset):
4293    """
4294    The result of applying TDT operation to the input Dataset.
4295
4296    Args:
4297        input_dataset (Dataset): Input Dataset to be transferred.
4298        send_epoch_end (bool, optional): Whether to send end of sequence to device or not. Default: ``True``.
4299        create_data_info_queue (bool, optional): Whether to create queue which stores
4300            types and shapes of data or not. Default: ``False``.
4301
4302    Raises:
4303        TypeError: If device_type is empty.
4304        ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'.
4305        RuntimeError: If dataset is unknown.
4306    """
4307
4308    def __init__(self, input_dataset, send_epoch_end=True, create_data_info_queue=False, queue_name=""):
4309        super().__init__(children=input_dataset)
4310        if queue_name == "":
4311            self.queue_name = str(uuid.uuid1())
4312            logger.info(f"queue_name is newly generated. value is {self.queue_name}")
4313        else:
4314            self.queue_name = queue_name
4315            logger.info(f"queue_name is read from compile cache. value is {self.queue_name}")
4316        self.device_type = context.get_context("device_target") if context else "CPU"
4317        self.device_id = context.get_context("device_id") if context else 0
4318
4319        self._send_epoch_end = replace_none(send_epoch_end, True)
4320        self._create_data_info_queue = create_data_info_queue
4321        self._to_device = None
4322        self.column_name = input_dataset.get_col_names()
4323
4324    def parse(self, children=None):
4325        total_batch = 0
4326        if hasattr(self.children[0], "__total_batch__"):
4327            total_batch = self.children[0].__total_batch__
4328            check_total_batch(total_batch)
4329        return cde.DataQueueNode(children[0], self.queue_name, self.device_type, self.device_id, self._send_epoch_end,
4330                                 total_batch, self._create_data_info_queue)
4331
4332    def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
4333        raise RuntimeError("TransferDataset is not iterable.")
4334
4335    def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
4336        raise RuntimeError("TransferDataset is not iterable.")
4337
4338    def __iter__(self):
4339        raise RuntimeError("TransferDataset is not iterable.")
4340
4341    def output_shapes(self):
4342        raise RuntimeError("TransferDataset does not support obtaining output_shapes.")
4343
4344    def output_types(self):
4345        raise RuntimeError("TransferDataset does not support obtaining output_types.")
4346
4347    @check_to_device_send
4348    def send(self, num_epochs=-1):
4349        """
4350        Send to device
4351        """
4352        if Dataset._noop_mode():
4353            return
4354        if self._to_device is not None:
4355            del self._to_device
4356        self._to_device = _ToDevice(self, num_epochs)
4357        self._to_device.send()
4358
4359    def stop_send(self):
4360        if self._to_device is not None:
4361            self._to_device.stop_send()
4362
4363    def continue_send(self):
4364        if self._to_device is not None:
4365            self._to_device.continue_send()
4366
4367    def get_data_info(self):
4368        """
4369        Get type and shape of current batch
4370        """
4371        if self._to_device is not None:
4372            return self._to_device.get_data_info()
4373        raise RuntimeError("Calling get_data_info with bad state.")
4374
4375    def get_mbuf_queue_size(self):
4376        """
4377        Get the number of elements inside the mbuf queue.
4378        """
4379        if self._to_device is not None:
4380            return self._to_device.get_mbuf_queue_size()
4381        raise RuntimeError("Device queue is not init, call get_mbuf_queue_size failed.")
4382
4383    def get_send_info(self):
4384        """
4385        In sink mode, return the send information of the dataset at this moment.
4386        Send information includes the number of sent batches, a time summary of fetching data on the host,
4387        and a time summary of sending data.
4388        """
4389        if self._to_device is not None:
4390            return self._to_device.get_send_info()
4391        raise RuntimeError("Calling get_send_info with bad state, data queue is not initialized.")
4392
4393    def get_offload_model(self):
4394        if self._to_device is not None:
4395            return self._to_device.get_offload_model(self.column_name)
4396
4397        raise RuntimeError("get_offload_model, _to_device is None")
4398
4399    def release(self):
4400        """
4401        Manually terminate Device Queue instead of relying on out of scope destruction.
4402        """
4403        if self._to_device is not None:
4404            self._to_device.release()
4405
4406    def _reset(self, step, dataset_size):
4407        if self._to_device is not None:
4408            logger.info("Reset the dataset pipeline to step: " + str(step) + ", epoch: " + str(step // dataset_size))
4409            self._to_device._reset(step, dataset_size)  # pylint: disable=protected-access
4410
4411
4412class Schema:
4413    """
4414    Class to represent a schema of a dataset.
4415
4416    Args:
4417        schema_file (str): Path of the schema file. Default: ``None``.
4418
4419    Raises:
4420        RuntimeError: If schema file failed to load.
4421
4422    Examples:
4423        >>> import mindspore.dataset as ds
4424        >>> from mindspore import dtype as mstype
4425        >>>
4426        >>> # Create schema; specify column name, mindspore.dtype and shape of the column
4427        >>> schema = ds.Schema()
4428        >>> schema.add_column(name='col1', de_type=mstype.int64, shape=[2])
4429    """
4430
4431    @check_schema
4432    def __init__(self, schema_file=None):
4433        self.schema_file = replace_none(schema_file, "")
4434        self.cpp_schema = cde.SchemaObj(self.schema_file)
4435
4436    @check_add_column
4437    def add_column(self, name, de_type, shape=None):
4438        """
4439        Add new column to the schema.
4440
4441        Args:
4442            name (str): The name of the new column.
4443            de_type (str): Data type of the column.
4444            shape (list[int], optional): Shape of the column.
4445                Default: ``None``, which is equivalent to [-1], an unknown shape of rank 1.
4446
4447        Raises:
4448            ValueError: If column type is unknown.
4449
4450        Examples:
4451            >>> import mindspore.dataset as ds
4452            >>> from mindspore import dtype as mstype
4453            >>>
4454            >>> schema = ds.Schema()
4455            >>> schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
4456        """
4457        if isinstance(de_type, typing.Type):
4458            de_type = mstype_to_detype(de_type)
4459            col_type = str(de_type)
4460        else:
4461            col_type = str(cde.DataType(de_type))
4462        if shape is None:
4463            self.cpp_schema.add_column(name, col_type)
4464        else:
4465            self.cpp_schema.add_column(name, col_type, shape)
4466
4467    def parse_columns(self, columns):
4468        """
4469        Parse the columns and add them to the schema.
4470
4471        Args:
4472            columns (Union[dict, list[dict], tuple[dict]]): Dataset attribute information, decoded from schema file.
4473
4474                - list[dict], each dict must contain the keys `name` and `type`; `shape` is optional.
4475
4476                - dict, columns.keys() are the column names and columns.values() are dicts containing `type`, with `shape` optional.
4477
4478        Raises:
4479            RuntimeError: If failed to parse columns.
4480            RuntimeError: If column's name field is missing.
4481            RuntimeError: If column's type field is missing.
4482
4483        Examples:
4484            >>> from mindspore.dataset import Schema
4485            >>> schema = Schema()
4486            >>> columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]},
4487            ...             {'name': 'label', 'type': 'int8', 'shape': [1]}]
4488            >>> schema.parse_columns(columns1)
4489            >>> columns2 = {'image': {'shape': [3, 3], 'type': 'int8'}, 'label': {'shape': [1], 'type': 'int8'}}
4490            >>> schema.parse_columns(columns2)
4491        """
4492        self.cpp_schema.parse_columns(json.dumps(columns, indent=2))
4493
4494    def to_json(self):
4495        """
4496        Get a JSON string of the schema.
4497
4498        Returns:
4499            str, JSON string of the schema.
4500
4501        Examples:
4502            >>> from mindspore.dataset import Schema
4503            >>> from mindspore import dtype as mstype
4504            >>>
4505            >>> schema = Schema()
4506            >>> schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
4507            >>> json = schema.to_json()
4508        """
4509        return self.cpp_schema.to_json()
4510
4511    def from_json(self, json_obj):
4512        """
4513        Load the schema from a JSON object.
4514
4515        Args:
4516            json_obj (dict): The parsed JSON object.
4517
4518        Raises:
4519            RuntimeError: If there is an unknown item in the object.
4520            RuntimeError: If the dataset type is missing in the object.
4521            RuntimeError: If columns are missing in the object.
4522
4523        Examples:
4524            >>> import json
4525            >>> from mindspore.dataset import Schema
4526            >>>
4527            >>> with open("/path/to/schema_file", "r") as file:
4528            ...     json_obj = json.load(file)
4529            ...     schema = Schema()
4530            ...     schema.from_json(json_obj)
4531        """
4532        self.cpp_schema.from_string(json.dumps(json_obj, indent=2))
4533
4534    def __str__(self):
4535        return self.to_json()
4536
4537    @staticmethod
4538    def get_num_rows(schema):
4539        schema_obj = schema
4540        if not isinstance(schema_obj, Schema):
4541            schema_obj = Schema(schema_obj)
4542        return schema_obj.cpp_schema.get_num_rows()
4543
4544
4545class DeserializedDataset(Dataset):
4546    def __init__(self, input_obj):
4547        super().__init__()
4548        self.input_obj = input_obj
4549
4550    def parse(self, children=None):
4551        if isinstance(self.input_obj, dict):
4552            json_str = json.dumps(self.input_obj)
4553            return cde.Dataset.from_json_string(json_str)
4554        return cde.Dataset.from_json_file(self.input_obj)
4555