• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16This dataset module supports various formats of datasets, including ImageNet, TFData,
17MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with
18high performance and parses data precisely. Some of the operations that are
19provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
20"""
21import atexit
22import glob
23import json
24import math
25import os
26import signal
27import stat
28import time
29import uuid
30import multiprocessing
31from multiprocessing.pool import RUN
32import queue
33from enum import Enum
34from functools import partial
35from importlib import import_module
36import sys
37import threading
38
39import copy
40import weakref
41import platform
42import psutil
43import numpy as np
44from scipy.io import loadmat
45from PIL import Image
46
47import mindspore._c_dataengine as cde
48from mindspore._c_expression import typing
49
50from mindspore import Tensor
51from mindspore import log as logger
52from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
53from mindspore.parallel._utils import _get_device_num
54
55import mindspore.dataset.transforms.py_transforms as py_transforms
56
57from . import samplers
58from .iterators import DictIterator, TupleIterator, DummyIterator, check_iterator_cleanup, _set_iterator_cleanup, \
59    ITERATORS_LIST, _unset_iterator_cleanup
60from .queue import _SharedQueue
61from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
62    check_rename, check_numpyslicesdataset, check_device_send, check_take, check_project, check_imagefolderdataset, \
63    check_mnist_cifar_dataset, check_manifestdataset, check_tfrecorddataset, check_vocdataset, check_cocodataset, \
64    check_celebadataset, check_minddataset, check_generatordataset, check_sync_wait, check_zip_dataset, \
65    check_add_column, check_textfiledataset, check_concat, check_random_dataset, check_split, \
66    check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
67    check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
68    check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
69    check_sbu_dataset
70from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
71    get_prefetch_size
72from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
73from ..core.validator_helpers import replace_none
74from ..core.py_util_helpers import ExceptionHandler
75from ..transforms.py_transforms_util import FuncWrapper
76
77try:
78    context = import_module("mindspore.context")
79except ModuleNotFoundError:
80    context = None
81
82
83class Shuffle(str, Enum):
84    GLOBAL: str = "global"
85    FILES: str = "files"
86    INFILE: str = "infile"
87
88
89ShuffleToShuffleMode = {Shuffle.FILES: cde.ShuffleMode.FILES,
90                        Shuffle.GLOBAL: cde.ShuffleMode.GLOBAL,
91                        Shuffle.INFILE: cde.ShuffleMode.INFILE}
92
93
94def shuffle_to_shuffle_mode(shuffle):
95    """class Shuffle Enum to int"""
96    shuffle_mode = cde.ShuffleMode.GLOBAL  # Global shuffle
97    if not isinstance(shuffle, Shuffle):
98        if shuffle is None or shuffle:
99            shuffle_mode = cde.ShuffleMode.GLOBAL  # Global shuffle
100        else:
101            shuffle_mode = cde.ShuffleMode.FALSE  # No shuffle
102    else:
103        shuffle_mode = ShuffleToShuffleMode[shuffle]
104    return shuffle_mode
105
106
107def shuffle_to_bool(shuffle):
108    """class Shuffle Enum to bool"""
109    shuffle_bool = True
110    if not isinstance(shuffle, Shuffle):
111        if shuffle is None:
112            shuffle_bool = None
113        elif shuffle:
114            shuffle_bool = True
115        else:
116            shuffle_bool = False
117    else:
118        shuffle_bool = True
119    return shuffle_bool
120
121
122@check_zip
123def zip(datasets):
124    """
125    Zip the datasets in the input tuple of datasets.
126
127    Args:
128        datasets (tuple of class Dataset): A tuple of datasets to be zipped together.
129            The number of datasets must be more than 1.
130
131    Returns:
132        ZipDataset, dataset zipped.
133
134    Raises:
135        ValueError: If the number of datasets is 1.
136        TypeError: If datasets is not a tuple.
137
138    Examples:
139            >>> # Create a dataset which is the combination of dataset_1 and dataset_2
140            >>> dataset = ds.zip((dataset_1, dataset_2))
141    """
142    if len(datasets) <= 1:
143        raise ValueError(
144            "Can't zip empty or just one dataset!")
145    for dataset in datasets:
146        if not isinstance(dataset, Dataset):
147            raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset))
148    return ZipDataset(datasets)
149
150
151def _get_operator_process():
152    """
153    Inner implemented method, mainly for passing sub-process id in C layer
154
155    Returns:
156         dict, mapping dict of operator id and corresponding process id.
157    """
158    global _OP_PROCESS
159    process_info = _OP_PROCESS
160    op_process = dict()
161    keys = process_info.keys()
162    fetched_all = True
163    for key in keys:
164        op_process[key] = list(process_info[key][1])
165        item_full = (len(process_info[key][1]) == process_info[key][0])
166        fetched_all = fetched_all and item_full
167    return op_process, fetched_all
168
169
170def _set_dataset_permissions(file_name, num_files):
171    """
172    set saved dataset files' permissions to 600
173    the rule of dataset filenames should be the same as those in C++.
174    """
175    num_digits = len(str(num_files - 1))
176    if num_files == 1:
177        paths = [file_name]
178    else:
179        paths = ["{}{}".format(file_name, str(x).rjust(num_digits, '0')) for x in range(num_files)]
180
181    for item in paths:
182        if os.path.exists(item):
183            os.chmod(item, stat.S_IRUSR | stat.S_IWUSR)
184            index_file = item + ".db"
185            if os.path.exists(index_file):
186                os.chmod(index_file, stat.S_IRUSR | stat.S_IWUSR)
187
188
189class Dataset:
190    """
191    Abstract class to represent a dataset in DataEngine's data pipeline.
192
193    This class is the base class of SourceDataset and Dataset, and represents
194    a node in the data flow graph.
195
196    Args:
197        num_parallel_workers (int, optional): Number of workers to process the dataset in parallel
198            (default=None).
199    """
200
201    def __init__(self, children=None, num_parallel_workers=None, cache=None):
202        # Note: children and parent are internal variables, not recommended for external using.
203        self.children = replace_none(children, [])
204        if isinstance(self.children, tuple):
205            self.children = list(self.children)
206        if not isinstance(self.children, list):
207            self.children = [self.children]
208
209        self.parent = []
210        for child in self.children:
211            child.parent.append(weakref.ref(self))
212        self.num_parallel_workers = num_parallel_workers
213        self.cache = cache
214
215        self._device_iter = 0
216        self._input_indexs = ()
217        self.saved_output_types = None
218        self.saved_output_shapes = None
219        self.dynamic_setting = [False, None]
220        self.saved_min_shapes = None
221        self.saved_max_shapes = None
222        self._col_names = None
223        self.dataset_size = None
224        self._batch_size = None
225        self._num_classes = None
226        self._repeat_count = None
227        self._class_indexing = None
228        self._sync = False
229
230    def create_ir_tree(self):
231        """
232        Internal method to build an IR tree.
233
234        Returns:
235            DatasetNode, the root node of the IR tree.
236            Dataset, the root dataset of the IR tree.
237        """
238        parent = self.parent
239        self.parent = []
240        dataset = copy.deepcopy(self)
241        global _OP_NAME
242        _OP_NAME = Dataset._get_operator_id(dataset)
243        ir_tree = dataset.parse_tree()
244        self.parent = parent
245        _init_device_info()
246        return ir_tree, dataset
247
248    def close_pool(self):
249        """
250        Close multiprocessing pool in dataset. If you are familiar with multiprocessing library, you can regard this
251        as a destructor for a processingPool object.
252        """
253        if hasattr(self, 'process_pool') and self.process_pool is not None:
254            self.process_pool.close()
255        for child in self.children:
256            child.close_pool()
257
258    def notify_watchdog(self):
259        if hasattr(self, 'sample_fn') and self.sample_fn is not None:
260            if self.sample_fn.multi_process:
261                self.sample_fn._abort_watchdog()  # pylint: disable=W0212
262        if hasattr(self, 'watch_dog') and self.watch_dog is not None and hasattr(self, 'eot') and self.eot is not None:
263            self._abort_watchdog()
264        for child in self.children:
265            child.notify_watchdog()
266
267    @staticmethod
268    def _get_operator_id(dataset):
269        """
270        Internal method to iterate the tree and obtain op_id of each operator.
271
272        Returns:
273            Dataset, the root dataset of the tree.
274        """
275        op_name = dict()
276        generator_process = dict()
277        op_name[str(dataset)] = 0
278        op_id = 1
279
280        def process_name(datasets, operator_id):
281            if not datasets:
282                return 0
283            temp = []
284            for item in datasets:
285                for d in item.children:
286                    temp.append(d)
287                    op_name[str(d)] = operator_id
288                    if isinstance(d, GeneratorDataset) and d.sample_fn and d.sample_fn.pids:
289                        generator_process[operator_id] = [d.num_parallel_workers, set(d.sample_fn.pids)]
290
291            operator_id = operator_id + 1
292            return process_name(temp, operator_id)
293
294        process_name([dataset], op_id)
295        if generator_process:
296            global _OP_PROCESS
297            _OP_PROCESS.update(generator_process)
298        return op_name
299
300    def parse_tree(self):
301        """
302        Internal method to parse the API tree into an IR tree.
303
304        Returns:
305            DatasetNode, the root node of the IR tree.
306        """
307        if len(self.parent) > 1:
308            raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
309        ir_children = [d.parse_tree() for d in self.children]
310        # Bootstrap can only be performed on a copy of the original dataset node.
311        # Bootstrap on original dataset node will make all iterators share the same process pool
312        self.iterator_bootstrap()
313        ir_node = self.parse(ir_children)
314        ir_node = self.post_parse(ir_node)
315        return ir_node
316
317    def __safe_deepcopy__(self, memodict, exclude=()):
318        if id(self) in memodict:
319            return memodict[id(self)]
320        cls = self.__class__
321        new_op = cls.__new__(cls)
322        memodict[id(self)] = new_op
323        for arg, value in self.__dict__.items():
324            if arg in exclude:
325                setattr(new_op, arg, value)
326            else:
327                try:
328                    setattr(new_op, arg, copy.deepcopy(value, memodict))
329                except TypeError:
330                    setattr(new_op, arg, value)
331        return new_op
332
333    def iterator_bootstrap(self):
334        pass
335
336    @staticmethod
337    def _noop_mode():
338        if _is_role_sched() or _is_role_pserver():
339            return True
340        return False
341
342    def __add__(self, datasets):
343        return self.concat(datasets)
344
345    def to_json(self, filename=""):
346        """
347        Serialize a pipeline into JSON string and dump into file if filename is provided.
348
349        Args:
350            filename (str): filename of JSON file to be saved as.
351
352        Returns:
353            str, JSON string of the pipeline.
354        """
355        ir_tree, _ = self.create_ir_tree()
356        return json.loads(ir_tree.to_json(filename))
357
358    @check_bucket_batch_by_length
359    def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes, element_length_function=None,
360                               pad_info=None, pad_to_bucket_boundary=False, drop_remainder=False):
361        """
362        Bucket elements according to their lengths. Each bucket will be padded and batched when
363        they are full.
364
365        A length function is called on each row in the dataset. The row is then
366        bucketed based on its length and bucket boundaries. When a bucket reaches its
367        corresponding size specified in bucket_batch_sizes, the entire bucket will be
368        padded according to batch_info, and then form a batch.
369        Each batch will be full, except one special case: the last batch for each bucket may not be full.
370
371        Args:
372            column_names (list[str]): Columns passed to element_length_function.
373            bucket_boundaries (list[int]): A list consisting of the upper boundaries
374                of the buckets. Must be strictly increasing. If there are n boundaries,
375                n+1 buckets are created: One bucket for [0, bucket_boundaries[0]), one
376                bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each
377                0<i<n-1, and last bucket for [bucket_boundaries[n-1], inf).
378            bucket_batch_sizes (list[int]): A list consisting of the batch sizes for
379                each bucket. Must contain len(bucket_boundaries)+1 elements.
380            element_length_function (Callable, optional): A function that takes in
381                M arguments where M = len(column_names) and returns an integer. If no value
382                provided, parameter M the len(column_names) must be 1, and the size of the first
383                dimension of that column will be taken as the length (default=None).
384            pad_info (dict, optional): The information about how to batch each column. The key
385                corresponds to the column name, and the value must be a tuple of 2 elements.
386                The first element corresponds to the shape to pad to, and the second
387                element corresponds to the value to pad with. If a column is not
388                specified, then that column will be padded to the longest in the current
389                batch, and 0 will be used as the padding value. Any None dimensions will
390                be padded to the longest in the current batch, unless if
391                pad_to_bucket_boundary is True. If no padding is wanted, set pad_info
392                to None (default=None).
393            pad_to_bucket_boundary (bool, optional): If True, will pad each None
394                dimension in pad_info to the bucket_boundary minus 1. If there are any
395                elements that fall into the last bucket, an error will occur
396                (default=False).
397            drop_remainder (bool, optional): If True, will drop the last batch for each
398                bucket if it is not a full batch (default=False).
399
400        Returns:
401            BucketBatchByLengthDataset, dataset bucketed and batched by length.
402
403        Examples:
404            >>> # Create a dataset where certain counts rows are combined into a batch
405            >>> # and drops the last incomplete batch if there is one.
406            >>> import numpy as np
407            >>> def generate_2_columns(n):
408            ...     for i in range(n):
409            ...         yield (np.array([i]), np.array([j for j in range(i + 1)]))
410            >>>
411            >>> column_names = ["col1", "col2"]
412            >>> dataset = ds.GeneratorDataset(generate_2_columns(8), column_names)
413            >>> bucket_boundaries = [5, 10]
414            >>> bucket_batch_sizes = [2, 1, 1]
415            >>> element_length_function = (lambda col1, col2: max(len(col1), len(col2)))
416            >>> # Will pad col2 to shape [bucket_boundaries[i]] where i is the
417            >>> # index of the bucket that is currently being batched.
418            >>> pad_info = {"col2": ([None], -1)}
419            >>> pad_to_bucket_boundary = True
420            >>> dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
421            ...                                          bucket_batch_sizes,
422            ...                                          element_length_function, pad_info,
423            ...                                          pad_to_bucket_boundary)
424        """
425        return BucketBatchByLengthDataset(self, column_names, bucket_boundaries, bucket_batch_sizes,
426                                          element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder)
427
428    @check_batch
429    def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
430              input_columns=None, output_columns=None, column_order=None, pad_info=None, python_multiprocessing=False,
431              max_rowsize=16):
432        """
433        Combine batch_size number of consecutive rows into batches.
434
435        For any child node, a batch is treated as a single row.
436        For any column, all the elements within that column must have the same shape.
437        If a per_batch_map callable is provided, it will be applied to the batches of tensors.
438
439        Note:
440            The order of using repeat and batch reflects the number of batches and per_batch_map.
441            It is recommended that the repeat operation applied after the batch operation finished.
442
443        Args:
444            batch_size (int or function): The number of rows each batch is created with. An
445                int or callable object which takes exactly 1 parameter, BatchInfo.
446            drop_remainder (bool, optional): Determines whether or not to drop the last block
447                whose data row number is less than batch size (default=False). If True, and if there are less
448                than batch_size rows available to make the last batch, then those rows will
449                be dropped and not propagated to the child node.
450            num_parallel_workers (int, optional): Number of workers(threads) to process the dataset in parallel
451                (default=None).
452            per_batch_map (callable, optional): Per batch map callable. A callable which takes
453                (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch
454                of Tensors on a given column. The number of lists should match with number of entries in input_columns.
455                The last parameter of the callable should always be a BatchInfo object. Per_batch_map should return
456                (list[Tensor], list[Tensor], ...). The length of each list in output should be same as the input.
457                output_columns is required if the number of output lists is different from input.
458            input_columns (Union[str, list[str]], optional): List of names of the input columns. The size of the list
459                should match with signature of per_batch_map callable (default=None).
460            output_columns (Union[str, list[str]], optional): List of names assigned to the columns
461                outputted by the last operation. This parameter is mandatory if len(input_columns) !=
462                len(output_columns). The size of this list must match the number of output
463                columns of the last operation. (default=None, output columns will have the same
464                name as the input columns, i.e., the columns will be replaced).
465            column_order (Union[str, list[str]], optional): Specifies the list of all the columns you need in the whole
466                dataset. The parameter is required when len(input_column) != len(output_column). Caution: the list here
467                is not just the columns specified in parameter input_columns and output_columns.
468            pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)}
469                would pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0
470                (default=None).
471            python_multiprocessing (bool, optional): Parallelize Python function per_batch_map with multi-processing.
472                This option could be beneficial if the function is computational heavy (default=False).
473            max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy
474                data between processes.  This is only used if python_multiprocessing is set to True (default 16 MB).
475
476        Returns:
477            BatchDataset, dataset batched.
478
479        Examples:
480            >>> # Create a dataset where every 100 rows are combined into a batch
481            >>> # and drops the last incomplete batch if there is one.
482            >>> dataset = dataset.batch(100, True)
483            >>> # resize image according to its batch number, if it's 5-th batch, resize to (5^2, 5^2) = (25, 25)
484            >>> def np_resize(col, batchInfo):
485            ...     output = col.copy()
486            ...     s = (batchInfo.get_batch_num() + 1) ** 2
487            ...     index = 0
488            ...     for c in col:
489            ...         img = Image.fromarray(c.astype('uint8')).convert('RGB')
490            ...         img = img.resize((s, s), Image.ANTIALIAS)
491            ...         output[index] = np.array(img)
492            ...         index += 1
493            ...     return (output,)
494            >>> dataset = dataset.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)
495        """
496        return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns,
497                            output_columns, column_order, pad_info, python_multiprocessing, max_rowsize)
498
499    @check_sync_wait
500    def sync_wait(self, condition_name, num_batch=1, callback=None):
501        """
502        Add a blocking condition to the input Dataset. A synchronize action will be applied.
503
504        Args:
505            condition_name (str): The condition name that is used to toggle sending next row.
506            num_batch (int): the number of batches without blocking at the start of each epoch.
507            callback (function): The callback function that will be invoked when sync_update is called.
508
509        Returns:
510            SyncWaitDataset, dataset added a blocking condition.
511
512        Raises:
513            RuntimeError: If condition name already exists.
514
515        Examples:
516            >>> import numpy as np
517            >>> def gen():
518            ...     for i in range(100):
519            ...         yield (np.array(i),)
520            >>>
521            >>> class Augment:
522            ...     def __init__(self, loss):
523            ...         self.loss = loss
524            ...
525            ...     def preprocess(self, input_):
526            ...         return input_
527            ...
528            ...     def update(self, data):
529            ...         self.loss = data["loss"]
530            >>>
531            >>> batch_size = 4
532            >>> dataset = ds.GeneratorDataset(gen, column_names=["input"])
533            >>>
534            >>> aug = Augment(0)
535            >>> dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
536            >>> dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
537            >>> dataset = dataset.batch(batch_size)
538            >>> count = 0
539            >>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
540            ...     assert data["input"][0] == count
541            ...     count += batch_size
542            ...     data = {"loss": count}
543            ...     dataset.sync_update(condition_name="policy", data=data)
544        """
545        return SyncWaitDataset(self, condition_name, num_batch, callback)
546
547    @check_shuffle
548    def shuffle(self, buffer_size):
549        """
550        Randomly shuffles the rows of this dataset using the following policy:
551
552        1. Make a shuffle buffer that contains the first buffer_size rows.
553        2. Randomly select an element from the shuffle buffer to be the next row
554           propagated to the child node.
555        3. Get the next row (if any) from the parent node and put it in the shuffle buffer.
556        4. Repeat steps 2 and 3 until there are no more rows left in the shuffle buffer.
557
558        A random seed can be provided to be used on the first epoch. In every subsequent
559        epoch, the seed is changed to a new one, randomly generated value.
560
561        Args:
562            buffer_size (int): The size of the buffer (must be larger than 1) for
563                shuffling. Setting buffer_size equal to the number of rows in the entire
564                dataset will result in a global shuffle.
565
566        Returns:
567            ShuffleDataset, dataset shuffled.
568
569        Raises:
570            RuntimeError: If exist sync operators before shuffle.
571
572        Examples:
573            >>> # dataset is an instance object of Dataset
574            >>> # Optionally set the seed for the first epoch
575            >>> ds.config.set_seed(58)
576            >>> # Create a shuffled dataset using a shuffle buffer of size 4
577            >>> dataset = dataset.shuffle(4)
578        """
579        return ShuffleDataset(self, buffer_size)
580
581    def flat_map(self, func):
582        """
583        Map `func` to each row in dataset and flatten the result.
584
585        The specified `func` is a function that must take one 'Ndarray' as input
586        and return a 'Dataset'.
587
588        Args:
589            func (function): A function that must take one 'Ndarray' as an argument and
590                return a 'Dataset'.
591
592        Returns:
593            Dataset, dataset applied by the function.
594
595        Examples:
596            >>> # use NumpySlicesDataset as an example
597            >>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]])
598            >>>
599            >>> def flat_map_func(array):
600            ...     # create a NumpySlicesDataset with the array
601            ...     dataset = ds.NumpySlicesDataset(array)
602            ...     # repeat the dataset twice
603            ...     dataset = dataset.repeat(2)
604            ...     return dataset
605            >>>
606            >>> dataset = dataset.flat_map(flat_map_func)
607            >>> # [[0, 1], [0, 1], [2, 3], [2, 3]]
608
609        Raises:
610            TypeError: If `func` is not a function.
611            TypeError: If `func` doesn't return a Dataset.
612        """
613        dataset = None
614        if not hasattr(func, '__call__'):
615            logger.error("func must be a function.")
616            raise TypeError("func must be a function.")
617
618        for row_data in self.create_tuple_iterator(output_numpy=True):
619            if dataset is None:
620                dataset = func(row_data)
621            else:
622                dataset += func(row_data)
623
624        if not isinstance(dataset, Dataset):
625            logger.error("flat_map must return a Dataset object.")
626            raise TypeError("flat_map must return a Dataset object.")
627        return dataset
628
629    @check_map
630    def map(self, operations, input_columns=None, output_columns=None, column_order=None,
631            num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16):
632        """
633        Apply each operation in operations to this dataset.
634
635        The order of operations is determined by the position of each operation in the operations parameter.
636        operations[0] will be applied first, then operations[1], then operations[2], etc.
637
638        Each operation will be passed one or more columns from the dataset as input, and zero or
639        more columns will be outputted. The first operation will be passed the columns specified
640        in input_columns as input. If there is more than one operator in operations, the outputted
641        columns of the previous operation are used as the input columns for the next operation.
642        The columns outputted by the very last operation will be assigned names specified by
643        output_columns.
644
645        Only the columns specified in column_order will be propagated to the child node. These
646        columns will be in the same order as specified in column_order.
647
648        Args:
649            operations (Union[list[TensorOp], list[functions]]): List of operations to be
650                applied on the dataset. Operations are applied in the order they appear in this list.
651            input_columns (Union[str, list[str]], optional): List of the names of the columns that will be passed to
652                the first operation as input. The size of this list must match the number of
653                input columns expected by the first operator. (default=None, the first
654                operation will be passed however many columns that are required, starting from
655                the first column).
656            output_columns (Union[str, list[str]], optional): List of names assigned to the columns outputted by
657                the last operation. This parameter is mandatory if len(input_columns) !=
658                len(output_columns). The size of this list must match the number of output
659                columns of the last operation. (default=None, output columns will have the same
660                name as the input columns, i.e., the columns will be replaced).
661            column_order (list[str], optional): Specifies the list of all the columns you need in the whole
662                dataset. The parameter is required when len(input_column) != len(output_column). Caution: the list here
663                is not just the columns specified in parameter input_columns and output_columns.
664            num_parallel_workers (int, optional): Number of threads used to process the dataset in
665                parallel (default=None, the value from the configuration will be used).
666            python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
667                option could be beneficial if the Python operation is computational heavy (default=False).
668            cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
669                (default=None, which means no cache is used).
670            callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None).
671            max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy
672                data between processes.  This is only used if python_multiprocessing is set to True (default 16 MB).
673
674
675        Returns:
676            MapDataset, dataset after mapping operation.
677
678        Examples:
679            >>> # dataset is an instance of Dataset which has 2 columns, "image" and "label".
680            >>>
681            >>> # Define two operations, where each operation accepts 1 input column and outputs 1 column.
682            >>> decode_op = c_vision.Decode(rgb=True)
683            >>> random_jitter_op = c_vision.RandomColorAdjust(brightness=(0.8, 0.8), contrast=(1, 1),
684            ...                                               saturation=(1, 1), hue=(0, 0))
685            >>>
686            >>> # 1) Simple map example.
687            >>>
688            >>> # Apply decode_op on column "image". This column will be replaced by the outputted
689            >>> # column of decode_op. Since column_order is not provided, both columns "image"
690            >>> # and "label" will be propagated to the child node in their original order.
691            >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"])
692            >>>
693            >>> # Decode and rename column "image" to "decoded_image".
694            >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"], output_columns=["decoded_image"])
695            >>>
696            >>> # Specify the order of the output columns.
697            >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"],
698            ...                       output_columns=None, column_order=["label", "image"])
699            >>>
700            >>> # Rename column "image" to "decoded_image" and also specify the order of the output columns.
701            >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"],
702            ...                       output_columns=["decoded_image"], column_order=["label", "decoded_image"])
703            >>>
704            >>> # Rename column "image" to "decoded_image" and keep only this column.
705            >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"],
706            ...                       output_columns=["decoded_image"], column_order=["decoded_image"])
707            >>>
708            >>> # A simple example for mapping pyfunc. Renaming columns and specifying column order
709            >>> # work in the same way as the previous examples.
710            >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
711            >>> dataset = dataset.map(operations=[(lambda x: x + 1)], input_columns=["data"])
712            >>>
713            >>> # 2) Map example with more than one operation.
714            >>>
715            >>> # Create a dataset where the images are decoded, then randomly color jittered.
716            >>> # decode_op takes column "image" as input and outputs one column. The column
717            >>> # outputted by decode_op is passed as input to random_jitter_op.
718            >>> # random_jitter_op will output one column. Column "image" will be replaced by
719            >>> # the column outputted by random_jitter_op (the very last operation). All other
720            >>> # columns are unchanged. Since column_order is not specified, the order of the
721            >>> # columns will remain the same.
722            >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"])
723            >>>
724            >>> # Rename the column outputted by random_jitter_op to "image_mapped".
725            >>> # Specifying column order works in the same way as examples in 1).
726            >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"],
727            ...                       output_columns=["image_mapped"])
728            >>>
729            >>> # Map with multiple operations using pyfunc. Renaming columns and specifying column order
730            >>> # work in the same way as examples in 1).
731            >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
732            >>> dataset = dataset.map(operations=[(lambda x: x * x), (lambda x: x - 1)], input_columns=["data"],
733            ...                                   output_columns=["data_mapped"])
734            >>>
735            >>> # 3) Example where number of input columns is not equal to number of output columns.
736            >>>
737            >>> # operations[0] is a lambda that takes 2 columns as input and outputs 3 columns.
738            >>> # operations[1] is a lambda that takes 3 columns as input and outputs 1 column.
739            >>> # operations[2] is a lambda that takes 1 column as input and outputs 4 columns.
740            >>> #
741            >>> # Note: The number of output columns of operation[i] must equal the number of
742            >>> # input columns of operation[i+1]. Otherwise, this map call will also result
743            >>> # in an error.
744            >>> operations = [(lambda x, y: (x, x + y, x + y + 1)),
745            ...               (lambda x, y, z: x * y * z),
746            ...               (lambda x: (x % 2, x % 3, x % 5, x % 7))]
747            >>>
748            >>> # Note: Since the number of input columns is not the same as the number of
749            >>> # output columns, the output_columns and column_order parameters must be
750            >>> # specified. Otherwise, this map call will also result in an error.
751            >>>
752            >>> dataset = ds.NumpySlicesDataset(data=([[0, 1, 2]], [[3, 4, 5]]), column_names=["x", "y"])
753            >>>
754            >>> # Propagate all columns to the child node in this order:
755            >>> dataset = dataset.map(operations, input_columns=["x", "y"],
756            ...                       output_columns=["mod2", "mod3", "mod5", "mod7"],
757            ...                       column_order=["mod2", "mod3", "mod5", "mod7"])
758            >>>
759            >>> # Propagate some columns to the child node in this order:
760            >>> dataset = dataset.map(operations, input_columns=["x", "y"],
761            ...                       output_columns=["mod2", "mod3", "mod5", "mod7"],
762            ...                       column_order=["mod7", "mod3", "col2"])
763        """
764
765        return MapDataset(self, operations, input_columns, output_columns, column_order, num_parallel_workers,
766                          python_multiprocessing, cache, callbacks, max_rowsize)
767
768    @check_filter
769    def filter(self, predicate, input_columns=None, num_parallel_workers=None):
770        """
771        Filter dataset by prediction.
772
773        Note:
774             If input_columns not provided or provided with empty, all columns will be used.
775
776        Args:
777            predicate (callable): Python callable which returns a boolean value. If False then filter the element.
778            input_columns (Union[str, list[str]], optional): List of names of the input columns, when
779                default=None, the predicate will be applied on all columns in the dataset.
780            num_parallel_workers (int, optional): Number of workers to process the dataset
781                in parallel (default=None).
782
783        Returns:
784            FilterDataset, dataset filtered.
785
786        Examples:
787            >>> # generator data(0 ~ 63)
788            >>> # filter the data that greater than or equal to 11
789            >>> dataset = dataset.filter(predicate=lambda data: data < 11, input_columns = ["data"])
790        """
791        return FilterDataset(self, predicate, input_columns, num_parallel_workers)
792
793    @check_repeat
794    def repeat(self, count=None):
795        """
796        Repeat this dataset `count` times. Repeat infinitely if the count is None or -1.
797
798        Note:
799            The order of using repeat and batch reflects the number of batches. It is recommended that
800            the repeat operation is used after the batch operation.
801
802        Args:
803            count (int): Number of times the dataset is going to be repeated (default=None).
804
805        Returns:
806            RepeatDataset, dataset repeated.
807
808        Examples:
809            >>> # dataset is an instance object of Dataset
810            >>>
811            >>> # Create a dataset where the dataset is repeated for 50 epochs
812            >>> dataset = dataset.repeat(50)
813            >>>
814            >>> # Create a dataset where each epoch is shuffled individually
815            >>> dataset = dataset.shuffle(10)
816            >>> dataset = dataset.repeat(50)
817            >>>
818            >>> # Create a dataset where the dataset is first repeated for
819            >>> # 50 epochs before shuffling. The shuffle operator will treat
820            >>> # the entire 50 epochs as one big dataset.
821            >>> dataset = dataset.repeat(50)
822            >>> dataset = dataset.shuffle(10)
823        """
824        return RepeatDataset(self, count)
825
826    @check_skip
827    def skip(self, count):
828        """
829        Skip the first N elements of this dataset.
830
831        Args:
832            count (int): Number of elements in the dataset to be skipped.
833
834        Returns:
835            SkipDataset, dataset that containing rows like origin rows subtract skipped rows.
836
837        Examples:
838            >>> # dataset is an instance object of Dataset
839            >>> # Create a dataset which skips first 3 elements from data
840            >>> dataset = dataset.skip(3)
841        """
842        return SkipDataset(self, count)
843
844    @check_take
845    def take(self, count=-1):
846        """
847        Takes at most given numbers of elements from the dataset.
848
849        Note:
850            1. If count is greater than the number of elements in the dataset or equal to -1,
851               all the elements in dataset will be taken.
852            2. The order of using take and batch matters. If take is before batch operation,
853               then take given number of rows; otherwise take given number of batches.
854
855        Args:
856            count (int, optional): Number of elements to be taken from the dataset (default=-1).
857
858        Returns:
859            TakeDataset, dataset taken.
860
861        Examples:
862            >>> # dataset is an instance object of Dataset
863            >>> # Create a dataset where the dataset includes 50 elements.
864            >>> dataset = dataset.take(50)
865        """
866        return TakeDataset(self, count)
867
868    def _get_absolute_split_sizes(self, sizes):
869        """
870        Internal method called by split to calculate absolute split sizes and to
871        do some error checking after calculating absolute split sizes.
872
873        Returns:
874            int, absolute split sizes of the dataset.
875        """
876        # Call get_dataset_size here and check input here because
877        # don't want to call this once in check_split and another time in
878        # here again
879        dataset_size = self.get_dataset_size()
880
881        if dataset_size is None or dataset_size <= 0:
882            raise RuntimeError("dataset_size is unknown, unable to split.")
883
884        if not isinstance(sizes, list):
885            raise RuntimeError("sizes must be a list.")
886
887        all_int = all(isinstance(item, int) for item in sizes)
888        if all_int:
889            sizes_sum = sum(sizes)
890            if sizes_sum != dataset_size:
891                raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}."
892                                   .format(sizes_sum, dataset_size))
893            return sizes
894
895        absolute_sizes = []
896        for item in sizes:
897            absolute_size = int(round(item * dataset_size))
898            if absolute_size == 0:
899                raise RuntimeError("Split percentage {} is too small.".format(item))
900            absolute_sizes.append(absolute_size)
901
902        absolute_sizes_sum = sum(absolute_sizes)
903
904        # if we still need more rows, give them to the first split.
905        # if we have too many rows, remove the extras from the first split that has
906        # enough rows.
907        size_difference = int(dataset_size - absolute_sizes_sum)
908        if size_difference > 0:
909            absolute_sizes[0] += size_difference
910        else:
911            for i, _ in enumerate(absolute_sizes):
912                if absolute_sizes[i] + size_difference > 0:
913                    absolute_sizes[i] += size_difference
914                    break
915
916        if sum(absolute_sizes) != dataset_size:
917            raise RuntimeError("Sum of calculated split sizes {} is not equal to dataset size {}."
918                               .format(absolute_sizes_sum, dataset_size))
919
920        return absolute_sizes
921
922    @check_split
923    def split(self, sizes, randomize=True):
924        """
925        Split the dataset into smaller, non-overlapping datasets.
926
927        This is a general purpose split function which can be called from any operator in the pipeline.
928        There is another, optimized split function, which will be called automatically if ds.split is
929        called where ds is a MappableDataset.
930
931        Args:
932            sizes (Union[list[int], list[float]]): If a list of integers [s1, s2, …, sn] is
933                provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
934                respectively. If the sum of all input sizes does not equal the original dataset size, an
935                error will throw.
936                If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1
937                and must sum to 1, otherwise an error will throw. The dataset will be split into n
938                Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the
939                original dataset.
940                If after rounding:
941
942                - Any size equals 0, an error will occur.
943                - The sum of split sizes < K, the difference of K - sigma(round(fi * k)) will be added to the first
944                  split.
945                - The sum of split sizes > K, the difference of sigma(round(fi * K)) - K will be removed from the first
946                  large enough split such that it will have at least 1 row after removing the difference.
947
948            randomize (bool, optional): Determines whether or not to split the data randomly (default=True).
949                If True, the data will be randomly split. Otherwise, each split will be created with
950                consecutive rows from the dataset.
951
952        Note:
953            1. Dataset cannot be sharded if split is going to be called.
954            2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead.
955               Shuffling the dataset may not be deterministic, which means the data in each split
956               will be different in each epoch.
957
958        Raises:
959            RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
960            RuntimeError: If `sizes` is list of integers and sum of all elements in sizes does not
961                equal the dataset size.
962            RuntimeError: If `sizes` is list of float and there is a split with size 0 after calculations.
963            RuntimeError: If the dataset is sharded prior to calling split.
964            ValueError: If `sizes` is list of float and not all floats are between 0 and 1, or if the
965                floats don't sum to 1.
966
967        Returns:
968            tuple(Dataset), a tuple of datasets that have been split.
969
970        Examples:
971            >>> # TextFileDataset is not a mappable dataset, so this non-optimized split will be called.
972            >>> # Since many datasets have shuffle on by default, set shuffle to False if split will be called!
973            >>> dataset = ds.TextFileDataset(text_file_dataset_dir, shuffle=False)
974            >>> train_dataset, test_dataset = dataset.split([0.9, 0.1])
975        """
976        if self.is_shuffled():
977            logger.warning("Dataset is shuffled before split.")
978
979        if self.is_sharded():
980            raise RuntimeError("Dataset should not be sharded before split.")
981
982        absolute_sizes = self._get_absolute_split_sizes(sizes)
983        splits = []
984        rows_to_skip = 0
985        for size in absolute_sizes:
986            ds = copy.deepcopy(self)
987            if randomize:
988                # want to shuffle the same way every epoch before split
989                # in alter_tree, shuffle buffer is minimum 10000, so use 10000 here
990                ds = ds.shuffle(10000)
991                ds.reshuffle_each_epoch = False
992
993            if rows_to_skip > 0:
994                ds = ds.skip(rows_to_skip)
995
996            ds = ds.take(size)
997            splits.append(ds)
998
999            rows_to_skip += size
1000
1001        return tuple(splits)
1002
1003    @check_zip_dataset
1004    def zip(self, datasets):
1005        """
1006        Zip the datasets in the sense of input tuple of datasets. Columns in the input datasets must have different
1007        name.
1008
1009        Args:
1010            datasets (Union[tuple, class Dataset]): A tuple of datasets or a single class Dataset
1011                to be zipped together with this dataset.
1012
1013        Returns:
1014            ZipDataset, dataset zipped.
1015
1016        Examples:
1017            >>> # Create a dataset which is the combination of dataset and dataset_1
1018            >>> dataset = dataset.zip(dataset_1)
1019        """
1020        if isinstance(datasets, tuple):
1021            datasets = (self, *datasets)
1022        elif isinstance(datasets, Dataset):
1023            datasets = (self, datasets)
1024        else:
1025            raise TypeError("Invalid datasets, expected Dataset object or tuple of Dataset, but got %s!" % datasets)
1026        return ZipDataset(datasets)
1027
1028    @check_concat
1029    def concat(self, datasets):
1030        """
1031        Concatenate the dataset objects in the input list.
1032        Performing "+" operation on dataset objects can achieve the same effect.
1033
1034        Note:
1035            The column name, and rank and type of the column data must be the same in the input datasets.
1036
1037        Args:
1038            datasets (Union[list, class Dataset]): A list of datasets or a single class Dataset
1039                to be concatenated together with this dataset.
1040
1041        Returns:
1042            ConcatDataset, dataset concatenated.
1043
1044        Examples:
1045            >>> # Create a dataset by concatenating dataset_1 and dataset_2 with "+" operator
1046            >>> dataset = dataset_1 + dataset_2
1047            >>> # Create a dataset by concatenating dataset_1 and dataset_2 with concat operation
1048            >>> dataset = dataset_1.concat(dataset_2)
1049        """
1050        if isinstance(datasets, Dataset):
1051            datasets = [self] + [datasets]
1052        elif isinstance(datasets, list):
1053            datasets = [self] + datasets
1054        else:
1055            raise TypeError("Invalid datasets, expected Dataset object or list of Dataset, but got %s!" % datasets)
1056        return ConcatDataset(datasets)
1057
1058    @check_rename
1059    def rename(self, input_columns, output_columns):
1060        """
1061        Rename the columns in input datasets.
1062
1063        Args:
1064            input_columns (Union[str, list[str]]): List of names of the input columns.
1065            output_columns (Union[str, list[str]]): List of names of the output columns.
1066
1067        Returns:
1068            RenameDataset, dataset renamed.
1069
1070        Examples:
1071            >>> # dataset is an instance object of Dataset
1072            >>> input_columns = ["input_col1", "input_col2", "input_col3"]
1073            >>> output_columns = ["output_col1", "output_col2", "output_col3"]
1074            >>>
1075            >>> # Create a dataset where input_col1 is renamed to output_col1, and
1076            >>> # input_col2 is renamed to output_col2, and input_col3 is renamed
1077            >>> # to output_col3.
1078            >>> dataset = dataset.rename(input_columns=input_columns, output_columns=output_columns)
1079        """
1080
1081        return RenameDataset(self, input_columns, output_columns)
1082
1083    @check_project
1084    def project(self, columns):
1085        """
1086        Project certain columns in input dataset.
1087
1088        The specified columns will be selected from the dataset and passed into
1089        the pipeline with the order specified. The other columns are discarded.
1090
1091        Args:
1092            columns(Union[str, list[str]]): List of names of the columns to project.
1093
1094        Returns:
1095            ProjectDataset, dataset projected.
1096
1097        Examples:
1098            >>> # dataset is an instance object of Dataset
1099            >>> columns_to_project = ["column3", "column1", "column2"]
1100            >>>
1101            >>> # Create a dataset that consists of column3, column1, column2
1102            >>> # in that order, regardless of the original order of columns.
1103            >>> dataset = dataset.project(columns=columns_to_project)
1104        """
1105
1106        return ProjectDataset(self, columns)
1107
1108    def build_vocab(self, columns, freq_range, top_k, special_tokens, special_first):
1109        """
1110        Function to create a Vocab from source dataset
1111
1112        Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab
1113        which contains top_k most frequent words (if top_k is specified)
1114
1115        Args:
1116
1117            columns(Union[str, list[str]]): Column names to get words from.
1118            freq_range(tuple[int]): A tuple of integers (min_frequency, max_frequency). Words within the frequency
1119                range will be stored.
1120                Naturally 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
1121                can be set to default, which corresponds to 0/total_words separately.
1122            top_k(int): Number of words to be built into vocab. top_k most frequent words are
1123                taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken
1124            special_tokens(list[str]): A list of strings, each one is a special token.
1125            special_first(bool): Whether special_tokens will be prepended/appended to vocab, If special_tokens
1126                is specified and special_first is set to default, special_tokens will be prepended.
1127
1128        Returns:
1129            Vocab, vocab built from the dataset.
1130
1131        Examples:
1132            >>> import numpy as np
1133            >>>
1134            >>> def gen_corpus():
1135            ...     # key: word, value: number of occurrences, reason for using letters is so their order is apparent
1136            ...     corpus = {"Z": 4, "Y": 4, "X": 4, "W": 3, "U": 3, "V": 2, "T": 1}
1137            ...     for k, v in corpus.items():
1138            ...         yield (np.array([k] * v, dtype='S'),)
1139            >>> column_names = ["column1"]
1140            >>> dataset = ds.GeneratorDataset(gen_corpus, column_names)
1141            >>> dataset = dataset.build_vocab(columns=["column1"],
1142            ...                               freq_range=(1, 10), top_k=5,
1143            ...                               special_tokens=["<pad>", "<unk>"],
1144            ...                               special_first=True)
1145
1146        """
1147        vocab = cde.Vocab()
1148        columns = replace_none(columns, [])
1149        if not isinstance(columns, list):
1150            columns = [columns]
1151
1152        freq_range = replace_none(freq_range, (0, 9223372036854775807))  # 9223372036854775807 is INT64_MAX, i.e. no upper bound
1153        if freq_range[0] is None:
1154            freq_range = (0, freq_range[1])
1155        if freq_range[1] is None:
1156            freq_range = (freq_range[0], 9223372036854775807)  # INT64_MAX, i.e. no upper bound
1157        special_tokens = replace_none(special_tokens, [])
1158        top_k = replace_none(top_k, 9223372036854775807)  # INT64_MAX, i.e. take all words
1159
1160        ir_tree, api_tree = self.create_ir_tree()
1161
1162        # vocab node
1163        vocab_node = cde.BuildVocabNode(ir_tree, vocab, columns, freq_range, top_k, special_tokens, special_first)
1164
1165        runtime_context = cde.PythonRuntimeContext()
1166        runtime_context.Init()
1167
1168        # build vocab
1169        consumer = cde.PythonBuildVocabConsumer()
1170        consumer.Init(vocab_node)
1171        runtime_context.AssignConsumer(consumer)
1172
1173        consumer.Start()
1174        del api_tree
1175
1176        return vocab
1177
1178    def build_sentencepiece_vocab(self, columns, vocab_size, character_coverage, model_type, params):
1179        """
1180        Function to create a SentencePieceVocab from a source dataset.
1181
1182        Build a SentencePieceVocab from a dataset.
1183
1184        Args:
1185
1186            columns(list[str]): Column names to get words from.
1187            vocab_size(int): Vocabulary size.
1188            character_coverage(float): Proportion of characters covered by the model, must be between
1189                        0.98 and 1.0. Good defaults are 0.9995 for languages with rich character sets like
1190                        Japanese or Chinese, and 1.0 for other languages with small character sets
1191                        like English or Latin.
1192            model_type(SentencePieceModel): Model type. Choose from unigram (default), bpe, char, or word.
1193                                        The input sentence must be pretokenized when using the word type.
1194            params(dict): Any extra optional parameters for the sentencepiece library, depending on your raw data.
1195
1196        Returns:
1197            SentencePieceVocab, vocab built from the dataset.
1198
1199        Examples:
1200            >>> import numpy as np
1201            >>> from mindspore.dataset.text import SentencePieceModel
1202            >>> def gen_corpus():
1203            ...     # key: word, value: number of occurrences, reason for using letters is so their order is apparent
1204            ...     corpus = {"Z": 4, "Y": 4, "X": 4, "W": 3, "U": 3, "V": 2, "T": 1}
1205            ...     for k, v in corpus.items():
1206            ...         yield (np.array([k] * v, dtype='S'),)
1207            >>> column_names = ["column1","column2","column3"]
1208            >>> dataset = ds.GeneratorDataset(gen_corpus, column_names)
1209            >>> dataset = dataset.build_sentencepiece_vocab(columns=["column3", "column1", "column2"],
1210            ...                                             vocab_size=5000,
1211            ...                                             character_coverage=0.9995,
1212            ...                                             model_type=SentencePieceModel.UNIGRAM,
1213            ...                                             params={})
1214        """
1215        vocab = cde.SentencePieceVocab()
1216
1217        ir_tree, api_tree = self.create_ir_tree()
1218
1219        # vocab node
1220        vocab_node = cde.BuildSentenceVocabNode(ir_tree, vocab, columns, vocab_size, character_coverage, model_type,
1221                                                params)
1222
1223        runtime_context = cde.PythonRuntimeContext()
1224        runtime_context.Init()
1225
1226        # build vocab
1227        consumer = cde.PythonBuildVocabConsumer()
1228        consumer.Init(vocab_node)
1229        runtime_context.AssignConsumer(consumer)
1230
1231        consumer.Start()
1232        del api_tree
1233
1234        return vocab
1235
1236    def apply(self, apply_func):
1237        """
1238        Apply a function to this dataset.
1239
1240        Args:
1241            apply_func (function): A function that must take one 'Dataset' as an argument and
1242                                   return a preprocessed 'Dataset'.
1243
1244        Returns:
1245            Dataset, dataset returned by applying the function.
1246
1247        Examples:
1248            >>> # dataset is an instance object of Dataset
1249            >>>
1250            >>> # Declare an apply_func function which returns a Dataset object
1251            >>> def apply_func(data):
1252            ...     data = data.batch(2)
1253            ...     return data
1254            >>>
1255            >>> # Use apply to call apply_func
1256            >>> dataset = dataset.apply(apply_func)
1257
1258        Raises:
1259            TypeError: If apply_func is not a function.
1260            TypeError: If apply_func doesn't return a Dataset.
1261        """
1262
1263        if not hasattr(apply_func, '__call__'):
1264            raise TypeError("apply_func must be a function.")
1265
1266        dataset = apply_func(self)
1267        if not isinstance(dataset, Dataset):
1268            raise TypeError("apply_func must return a dataset.")
1269        return dataset
1270
1271    @check_device_send
1272    def device_que(self, send_epoch_end=True, create_data_info_queue=False):
1273        """
1274        Return a TransferDataset that sends data to the device.
1275
1276        Args:
1277            send_epoch_end (bool, optional): Whether to send the end-of-sequence flag to the device or not (default=True).
1278            create_data_info_queue (bool, optional): Whether to create a queue which stores
1279                types and shapes of data or not (default=False).
1280
1281        Note:
1282            If the device is Ascend, features of the data will be transferred one by one. The maximum
1283            amount of data that can be transferred at a time is 256M.
1284
1285        Returns:
1286            TransferDataset, dataset for transferring.
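
        Examples:
            >>> # A minimal usage sketch (assumes dataset is an instance object of Dataset);
            >>> # the returned TransferDataset is typically driven with send() and stopped with stop_send().
            >>> transfer_dataset = dataset.device_que()
            >>> transfer_dataset.send()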
1287        """
1288        return self.to_device(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue)
1289
1290    @check_device_send
1291    def to_device(self, send_epoch_end=True, create_data_info_queue=False):
1292        """
1293        Transfer data from CPU to GPU, Ascend, or other devices.
1294
1295        Args:
1296            send_epoch_end (bool, optional): Whether to send the end of sequence to device or not (default=True).
1297            create_data_info_queue (bool, optional): Whether to create queue which stores
1298                types and shapes of data or not (default=False).
1299
1300        Note:
1301            If the device is Ascend, features of the data will be transferred one by one. The maximum
1302            amount of data that can be transferred at a time is 256M.
1303
1304        Returns:
1305            TransferDataset, dataset for transferring.
1306
1307        Raises:
1308            RuntimeError: If the distribution file path is given but cannot be read.
1309        """
1310        return TransferDataset(self, send_epoch_end, create_data_info_queue)
1311
1312    @check_save
1313    def save(self, file_name, num_files=1, file_type='mindrecord'):
1314        """
1315        Save the dynamic data processed by the dataset pipeline in a common dataset format.
1316        Supported dataset formats: 'mindrecord' only.
1317
1318        Implicit type casting exists when saving data as 'mindrecord'. The transform table shows how to do type casting.
1319
1320        .. list-table:: Implicit Type Casting when Saving as 'mindrecord'
1321           :widths: 25 25 50
1322           :header-rows: 1
1323
1324           * - Type in 'dataset'
1325             - Type in 'mindrecord'
1326             - Details
1327           * - bool
1328             - None
1329             - Not supported
1330           * - int8
1331             - int32
1332             -
1333           * - uint8
1334             - bytes(1D uint8)
1335             - Drop dimension
1336           * - int16
1337             - int32
1338             -
1339           * - uint16
1340             - int32
1341             -
1342           * - int32
1343             - int32
1344             -
1345           * - uint32
1346             - int64
1347             -
1348           * - int64
1349             - int64
1350             -
1351           * - uint64
1352             - None
1353             - Not supported
1354           * - float16
1355             - float32
1356             -
1357           * - float32
1358             - float32
1359             -
1360           * - float64
1361             - float64
1362             -
1363           * - string
1364             - string
1365             - Multi-dimensional string not supported
1366
1367        Note:
1368            1. To save the samples in order, set dataset's shuffle to False and num_files to 1.
1369            2. Before calling the function, do not use the batch operator, the repeat operator, or data augmentation
1370               operators with a random attribute in the map operator.
1371            3. When array dimension is variable, one-dimensional arrays or
1372               multi-dimensional arrays with variable dimension 0 are supported.
1373            4. MindRecord does not support DE_UINT64, multi-dimensional DE_UINT8 (which drops a dimension) or
1374               multi-dimensional DE_STRING.
1375
1376        Args:
1377            file_name (str): Path to dataset file.
1378            num_files (int, optional): Number of dataset files (default=1).
1379            file_type (str, optional): Dataset format (default='mindrecord').
1380
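        Examples:
            >>> # A minimal sketch; "./train_data.mindrecord" is a hypothetical output path, and dataset is
            >>> # assumed to be an unbatched, non-repeated pipeline (see the notes above).
            >>> dataset.save("./train_data.mindrecord", num_files=1)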
1381        """
1382        ir_tree, api_tree = self.create_ir_tree()
1383
1384        runtime_context = cde.PythonRuntimeContext()
1385        runtime_context.Init()
1386        consumer = cde.PythonSaveToDisk(file_name, num_files, file_type)
1387        consumer.Init(ir_tree)
1388        runtime_context.AssignConsumer(consumer)
1389
1390        consumer.Save()
1391        _set_dataset_permissions(file_name, num_files)
1392        del api_tree
1393
1394    @check_tuple_iterator
1395    def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
1396        """
1397        Create an iterator over the dataset. The data retrieved will be a list of ndarrays.
1398
1399        To specify which columns to list and their order, use the columns parameter. If columns
1400        is not provided, the order of the columns will remain unchanged.
1401
1402        Args:
1403            columns (list[str], optional): List of columns to be used to specify the order of columns
1404                (default=None, means all columns).
1405            num_epochs (int, optional): Maximum number of epochs that the iterator can be iterated
1406                (default=-1, the iterator can be iterated for an infinite number of epochs).
1407            output_numpy (bool, optional): Whether or not to output NumPy datatype.
1408                If output_numpy=False, iterator will output MSTensor (default=False).
1409            do_copy (bool, optional): When the output data type is mindspore.Tensor,
1410                use this parameter to select the conversion method. Set it to False for better performance (default=True).
1411
1412        Returns:
1413            TupleIterator, tuple iterator over the dataset.
1414
1415        Examples:
1416            >>> # dataset is an instance object of Dataset
1417            >>> iterator = dataset.create_tuple_iterator()
1418            >>> for item in iterator:
1419            ...     # item is a list
1420            ...     print(type(item))
1421            ...     break
1422            <class 'list'>
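            >>>
            >>> # A sketch (assuming the dataset has columns "col1" and "col2"): reorder the output
            >>> # columns and iterate a single epoch with NumPy output.
            >>> for item in dataset.create_tuple_iterator(columns=["col2", "col1"], num_epochs=1, output_numpy=True):
            ...     pass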
1423        """
1424        if output_numpy is None:
1425            output_numpy = False
1426
1427        if Dataset._noop_mode():
1428            return DummyIterator(self, 'tuple')
1429        return TupleIterator(self, columns, num_epochs, output_numpy, do_copy)
1430
1431    @check_dict_iterator
1432    def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
1433        """
1434        Create an iterator over the dataset. The data retrieved will be a dictionary datatype.
1435
1436        The order of the columns in the dictionary may not be the same as the original order.
1437
1438        Args:
1439            num_epochs (int, optional): Maximum number of epochs that the iterator can be iterated
1440                (default=-1, the iterator can be iterated for an infinite number of epochs).
1441            output_numpy (bool, optional): Whether or not to output NumPy datatype,
1442                if output_numpy=False, iterator will output MSTensor (default=False).
1443
1444        Returns:
1445            DictIterator, dictionary iterator over the dataset.
1446
1447        Examples:
1448            >>> # dataset is an instance object of Dataset
1449            >>> iterator = dataset.create_dict_iterator()
1450            >>> for item in iterator:
1451            ...     # item is a dict
1452            ...     print(type(item))
1453            ...     break
1454            <class 'dict'>
1455        """
1456        if output_numpy is None:
1457            output_numpy = False
1458
1459        if Dataset._noop_mode():
1460            return DummyIterator(self, 'dict')
1461        return DictIterator(self, num_epochs, output_numpy)
1462
1463    def __iter__(self):
1464        """Create an iterator over the dataset."""
1465        return self.create_tuple_iterator(num_epochs=1)
1466
1467    @property
1468    def input_indexs(self):
1469        """
1470        Get the input index information.
1471
1472        Returns:
1473            tuple, tuple of the input index information.
1474
1475        Examples:
1476            >>> # dataset is an instance object of Dataset
1477            >>> # set input_indexs
1478            >>> dataset.input_indexs = 10
1479            >>> print(dataset.input_indexs)
1480            10
1481        """
1482        if self._input_indexs != ():
1483            return self._input_indexs
1484
1485        # find input_indexes of children
1486        children_input_index = [child.input_indexs for child in self.children]
1487
1488        # in case of more than one child, return the first input_indexes
1489        for cix in children_input_index:
1490            if cix != ():
1491                return cix
1492
1493        # if all children's input_indexes are () or the node is a leaf
1494        return self._input_indexs
1495
1496    @input_indexs.setter
1497    def input_indexs(self, value):
1498        self._input_indexs = value
1499
1500    def copy_batch_size(self, value):
1501        self._batch_size = value
1502
1503    def _init_tree_getters(self):
1504        """
1505        Get pipeline information.
1506        """
1507        ir_tree, api_tree = self.create_ir_tree()
1508
1509        runtime_context = cde.PythonRuntimeContext()
1510        runtime_context.Init()
1511        getter = cde.TreeGetters()
1512        getter.Init(ir_tree)
1513        runtime_context.AssignConsumer(getter)
1514        return getter, runtime_context, api_tree
1515
1516    def __init_size_getter(self):
1517        """
1518        Get pipeline information.
1519        """
1520        ir_tree, api_tree = self.create_ir_tree()
1521
1522        runtime_context = cde.PythonRuntimeContext()
1523        runtime_context.Init()
1524        getter = cde.DatasetSizeGetters()
1525        getter.Init(ir_tree)
1526        runtime_context.AssignConsumer(getter)
1527        return getter, runtime_context, api_tree
1528
1529    def get_col_names(self):
1530        """
1531        Return the names of the columns in the dataset.
1532
1533        Returns:
1534            list, list of column names in the dataset.
1535
1536        Examples:
1537            >>> # dataset is an instance object of Dataset
1538            >>> col_names = dataset.get_col_names()
1539        """
1540        if self._col_names is None:
1541            runtime_getter = self._init_tree_getters()
1542            self._col_names = runtime_getter[0].GetColumnNames()
1543            self.close_pool()
1544            runtime_getter[2].notify_watchdog()
1545        return self._col_names
1546
1547    def output_shapes(self):
1548        """
1549        Get the shapes of output data.
1550
1551        Returns:
1552            list, list of shapes of each column.
1553
1554        Examples:
1555            >>> # dataset is an instance object of Dataset
1556            >>> output_shapes = dataset.output_shapes()
1557        """
1558        if self.saved_output_shapes is None:
1559            runtime_getter = self._init_tree_getters()
1560            self.saved_output_shapes = runtime_getter[0].GetOutputShapes()
1561            self.saved_output_types = runtime_getter[0].GetOutputTypes()
1562            self.close_pool()
1563            runtime_getter[2].notify_watchdog()
1564        if self.dynamic_setting[0]:
1565            self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes = self._dynamic_output_shapes()
1566        return self.saved_output_shapes
1567
1568    def output_types(self):
1569        """
1570        Get the types of output data.
1571
1572        Returns:
1573            list, list of data types.
1574
1575        Examples:
1576            >>> # dataset is an instance object of Dataset
1577            >>> output_types = dataset.output_types()
1578        """
1579        if self.saved_output_types is None:
1580            runtime_getter = self._init_tree_getters()
1581            self.saved_output_shapes = runtime_getter[0].GetOutputShapes()
1582            self.saved_output_types = runtime_getter[0].GetOutputTypes()
1583            self.close_pool()
1584            runtime_getter[2].notify_watchdog()
1585        if self.dynamic_setting[0]:
1586            self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes = self._dynamic_output_shapes()
1587        return self.saved_output_types
1588
1589    def get_dataset_size(self):
1590        """
1591        Return the number of batches in an epoch.
1592
1593        Returns:
1594            int, number of batches.
1595
1596        Examples:
1597            >>> # dataset is an instance object of Dataset
1598            >>> dataset_size = dataset.get_dataset_size()
1599        """
1600        if self.dataset_size is None:
1601            runtime_getter = self.__init_size_getter()
1602            self.dataset_size = runtime_getter[0].GetDatasetSize(False)
1603            self.close_pool()
1604            runtime_getter[2].notify_watchdog()
1605        return self.dataset_size
1606
1607    def set_dynamic_columns(self, columns=None):
1608        """
1609        Set dynamic shape information of the source data. It should be set after the pipeline is defined.
1610
1611        Args:
1612            columns (dict): A dict that contains the shape information of each column in the dataset.
1613                A value of :py:obj:`None` for shape[i] indicates that the data length of shape[i] is dynamic.
1614
1615        Examples:
1616            >>> import numpy as np
1617            >>>
1618            >>> def generator1():
1619            ...     for i in range(1, 100):
1620            ...         yield np.ones((16, i, 83)), np.array(i)
1621            >>>
1622            >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
1623            >>> dataset.set_dynamic_columns(columns={"data1": [16, None, 83], "data2": []})
1624        """
1625        if not isinstance(columns, dict):
1626            raise TypeError("Pass a dict to set dynamic shape, example: {\"data1\": [16, None, 256]}")
1627        self.dynamic_setting[0] = True
1628        self.dynamic_setting[1] = columns
1629
1630    def dynamic_min_max_shapes(self):
1631        """
1632        Get the minimum and maximum data lengths of dynamic source data, for graph compilation with dynamic shapes.
1633
1634        Returns:
1635            lists, min_shapes, max_shapes of source data.
1636
1637        Examples:
1638            >>> import numpy as np
1639            >>>
1640            >>> def generator1():
1641            ...     for i in range(1, 100):
1642            ...         yield np.ones((16, i, 83)), np.array(i)
1643            >>>
1644            >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
1645            >>> dataset.set_dynamic_columns(columns={"data1": [16, None, 83], "data2": []})
1646            >>> min_shapes, max_shapes = dataset.dynamic_min_max_shapes()
1647        """
1648        if self.saved_min_shapes is None or self.saved_max_shapes is None:
1649            self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes = self._dynamic_output_shapes()
1650        return self.saved_min_shapes, self.saved_max_shapes
1651
1652    def _dynamic_output_shapes(self):
1653        """
1654        Get dynamic information of source data.
1655
1656        Returns:
1657            lists, dynamic_shapes, min_shapes, max_shapes of source data.
1658        """
1659        if not self.dynamic_setting[1]:
1660            raise RuntimeError("dynamic_columns is not set, call set_dynamic_columns() on the final dataset operation.")
1661
1662        if self.saved_output_shapes is not None and self.saved_min_shapes is not None and \
1663                self.saved_max_shapes is not None:
1664            return self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes
1665
1666        logger.warning("Calculating dynamic shape of input data, this will take a few minutes...")
1667        # Assume data1 shape is dynamic, data2 shape is fixed
1668        # {"data1": [batch_size, None, feat_len], "data2": [batch_size, feat_len]}
1669        dynamic_columns = self.dynamic_setting[1]
1670        # ["data1", "data2"]
1671        dataset_columns = self.get_col_names()
1672        for column in dynamic_columns:
1673            if column not in dataset_columns:
1674                raise RuntimeError("dynamic column [" + column + "] does not match any column in dataset: " +
1675                                   str(dataset_columns))
1676
1677        # Shape[1] of data1 is variable
1678        # {"data1": {(batch_size, 100, feat_len), (16, 200, 83)}, "data2": {(batch_size, feat_len)}}
1679        column_shape_set = {col: set() for col in dataset_columns}
1680        dataset_size_counter = 0
1681        for data in self.create_dict_iterator(num_epochs=1, output_numpy=True):
1682            dataset_size_counter += 1
1683            for col in data.keys():
1684                if col in dynamic_columns:
1685                    shape_mismatch = "dynamic column [" + col + "] with shape " + str(dynamic_columns[col]) + \
1686                    " does not match dataset column [" + col + "] with shape " + str(list(data[col].shape))
1687                    if data[col].ndim != len(dynamic_columns[col]):
1688                        raise RuntimeError(shape_mismatch)
1689                    for dim in range(len(dynamic_columns[col])):
1690                        if dynamic_columns[col][dim] is not None and dynamic_columns[col][dim] != data[col].shape[dim]:
1691                            raise RuntimeError(shape_mismatch)
1692                column_shape_set[col].add(tuple(data[col].shape))
1693
1694        # we get dataset_size after dryrun
1695        self.dataset_size = dataset_size_counter
1696
1697        min_shapes, max_shapes, dynamic_shapes = list(), list(), list()
1698        for col, shape_set in column_shape_set.items():
1699            if len(shape_set) > 1:
1700                if col not in dynamic_columns:
1701                    raise RuntimeError("column [" + col + "] has dynamic shape but not set by set_dynamic_columns()" +
1702                                       ", shapes of [" + col + "]: " + str(list(shape_set)))
1703                shape_npy = np.array(list(shape_set))
1704                max_shape = shape_npy.max(axis=0)
1705                min_shape = shape_npy.min(axis=0)
1706
1707                # Set min shape to 1 due to unknown shuffle
1708                min_shape = np.where(np.equal(dynamic_columns[col], None), 1, min_shape)
1709                # Set dynamic dim to -1 for ME
1710                dynamic_shape = np.where(np.equal(dynamic_columns[col], None), -1, dynamic_columns[col])
1711
1712                max_shapes.append(max_shape.tolist())
1713                min_shapes.append(min_shape.tolist())
1714                dynamic_shapes.append(dynamic_shape.tolist())
1715            else:
1716                # Also append the fixed shape to keep the order of column shapes
1717                fix_shape = list(list(shape_set)[0])
1718                max_shapes.append(fix_shape)
1719                min_shapes.append(fix_shape)
1720                dynamic_shapes.append(fix_shape)
1721                if col in dynamic_columns:
1722                    logger.warning("column [" + col + "] has no dynamic shape but set by set_dynamic_columns()")
1723                    # Set min shape to 1 due to unknown shuffle
1724                    min_shapes[-1] = np.where(np.equal(dynamic_columns[col], None), 1, fix_shape).tolist()
1725                    # Set dynamic dim to -1 for ME
1726                    dynamic_shapes[-1] = np.where(np.equal(dynamic_columns[col], None), -1, fix_shape).tolist()
1727        return dynamic_shapes, min_shapes, max_shapes
1728
1729    def num_classes(self):
1730        """
1731        Get the number of classes in a dataset.
1732
1733        Returns:
1734            int, number of classes.
1735
1736        Examples:
1737            >>> # dataset is an instance object of Dataset
1738            >>> num_classes = dataset.num_classes()
1739        """
1740        if self._num_classes is None:
1741            runtime_getter = self._init_tree_getters()
1742            self._num_classes = runtime_getter[0].GetNumClasses()
1743            self.close_pool()
1744            runtime_getter[2].notify_watchdog()
1745        if self._num_classes == -1:
1746            return None
1747        return self._num_classes
1748
1749    def get_sync_notifiers(self):
1750        if self.children:
1751            return self.children[0].get_sync_notifiers()
1752        return {}
1753
1754    def disable_sync(self):
1755        if self.children:
1756            return self.children[0].disable_sync()
1757        return {}
1758
1759    def is_sync(self):
1760        if self.children:
1761            return self.children[0].is_sync()
1762        return False
1763
1764    def sync_update(self, condition_name, num_batch=None, data=None):
1765        """
1766        Release a blocking condition and trigger callback with given data.
1767
1768        Args:
1769            condition_name (str): The condition name that is used to toggle sending next row.
1770            num_batch (Union[int, None]): The number of batches (rows) that are released.
1771                When num_batch is None, it will default to the number specified by the
1772                sync_wait operator (default=None).
1773            data (Any): The data passed to the callback, user defined (default=None).
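
        Examples:
            >>> # A minimal sketch of pairing sync_wait with sync_update; assumes dataset is an
            >>> # instance object of Dataset and "cb" is the condition name given to sync_wait.
            >>> dataset = dataset.sync_wait(condition_name="cb", num_batch=1)
            >>> for data in dataset.create_dict_iterator(num_epochs=1):
            ...     # consume data, then release one more batch through the pipeline
            ...     dataset.sync_update(condition_name="cb")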
1774        """
1775        if (not isinstance(num_batch, int) and num_batch is not None) or \
1776                (isinstance(num_batch, int) and num_batch <= 0):
1777            # throwing exception, disable all sync_wait in pipeline
1778            self.disable_sync()
1779            raise RuntimeError("Sync_update batch size can only be a positive integer, got: {}.".format(num_batch))
1780        notifiers_dict = self.get_sync_notifiers()
1781        if not isinstance(condition_name, str):
1782            raise TypeError("Argument condition_name with value {} is not of type str, but got {}."
1783                            .format(condition_name, type(condition_name)))
1784        if condition_name not in notifiers_dict:
1785            # throwing exception, disable all sync_wait in pipeline
1786            self.disable_sync()
1787            raise RuntimeError("Condition name not found.")
1788        if num_batch is not None:
1789            num_batch *= self.get_batch_size()
1790        notifiers_dict[condition_name](num_batch, data)
1791
1792    def get_batch_size(self):
1793        """
1794        Return the size of a batch.
1795
1796        Returns:
1797            int, the number of rows in a batch.
1798
1799        Examples:
1800            >>> # dataset is an instance object of Dataset
1801            >>> batch_size = dataset.get_batch_size()
1802        """
1803        if self._batch_size is None:
1804            runtime_getter = self._init_tree_getters()
1805            self._batch_size = runtime_getter[0].GetBatchSize()
1806        if self._batch_size is None:
1807            self._batch_size = 1
1808        return self._batch_size
1809
1810    def get_repeat_count(self):
1811        """
1812        Get the repeat count of RepeatDataset in the pipeline (default is 1).
1813
1814        Returns:
1815            int, the count of repeat.
1816
1817        Examples:
1818            >>> # dataset is an instance object of Dataset
1819            >>> repeat_count = dataset.get_repeat_count()
1820        """
1821        if self._repeat_count is None:
1822            runtime_getter = self._init_tree_getters()
1823            self._repeat_count = runtime_getter[0].GetRepeatCount()
1824        if self._repeat_count is None:
1825            self._repeat_count = 1
1826        return self._repeat_count
1827
1828    def get_class_indexing(self):
1829        """
1830        Return the class index.
1831
1832        Returns:
1833            dict, a str-to-int mapping from label name to index.
1834            dict, a str-to-list<int> mapping from label name to index for COCO ONLY. The second number
1835            in the list is used to indicate the super category.
1836
1837        Examples:
1838            >>> # dataset is an instance object of Dataset
1839            >>> class_indexing = dataset.get_class_indexing()
1840        """
1841        if self.children:
1842            return self.children[0].get_class_indexing()
1843        return {}
1844
1845    def reset(self):
1846        """Reset the dataset for the next epoch."""
1847
1848    def is_shuffled(self):
1849        """Returns True if the dataset or any of its children is shuffled."""
1850        for input_dataset in self.children:
1851            if input_dataset.is_shuffled():
1852                return True
1853
1854        return False
1855
1856    def is_sharded(self):
1857        """Returns True if the dataset or any of its children is sharded."""
1858        for input_dataset in self.children:
1859            if input_dataset.is_sharded():
1860                return True
1861
1862        return False
1863
1864    def parse(self, children=None):
1865        raise NotImplementedError("Dataset has to implement parse method.")
1866
1867    def post_parse(self, ir_node):
1868        if self.cache:
1869            ir_node = ir_node.set_cache_client(self.cache.cache_client)
1870        if self.num_parallel_workers:
1871            ir_node = ir_node.set_num_workers(self.num_parallel_workers)
1872
1873        return ir_node
1874
1875
1876class SourceDataset(Dataset):
1877    """
1878    Abstract class to represent a source dataset which produces content for the data pipeline.
1879    """
1880
1881    def __init__(self, num_parallel_workers=None, num_samples=None, shuffle=True, num_shards=None, shard_id=None,
1882                 cache=None):
1883        super().__init__(num_parallel_workers=num_parallel_workers, cache=cache)
1884        self.num_samples = replace_none(num_samples, 0)
1885        self.num_shards = replace_none(num_shards, 1)
1886        self.shard_id = replace_none(shard_id, 0)
1887
1888        if shuffle is not None and not isinstance(shuffle, (bool, Shuffle)):
1889            raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like 'Shuffle.GLOBAL' or "
1890                            "'Shuffle.FILES' or 'Shuffle.INFILE'.")
1891
1892        self.shuffle_flag = 2  # Global shuffle
1893        if not isinstance(shuffle, Shuffle):
1894            if shuffle is None or shuffle:
1895                self.shuffle_flag = 2  # Global shuffle
1896            else:
1897                self.shuffle_flag = 0  # No shuffle
1898        else:
1899            if shuffle == Shuffle.GLOBAL:
1900                self.shuffle_flag = 2  # Global shuffle
1901            elif shuffle == Shuffle.FILES:
1902                self.shuffle_flag = 1  # Files shuffle
1903            elif shuffle == Shuffle.INFILE:
1904                self.shuffle_flag = 3  # Infile shuffle
1905
1906    def parse(self, children=None):
1907        raise NotImplementedError("Dataset has to implement parse method.")
1908
1909    @staticmethod
1910    def _find_files(patterns):
1911        """
1912        Utility function to search for files with the given glob patterns.
1913
1914        Args:
1915            patterns (Union[str, list[str]]): String or list of patterns to be searched.
1916
1917        Returns:
1918            list, list of files.
1919        """
1920
1921        if not isinstance(patterns, list):
1922            patterns = [patterns]
1923
1924        file_list = []
1925        unmatched_patterns = []
1926        for pattern in patterns:
1927            matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)]
1928
1929            if matches:
1930                file_list.extend(matches)
1931            else:
1932                unmatched_patterns.append(pattern)
1933
1934        if unmatched_patterns:
1935            raise ValueError("The following patterns did not match any files: {}.".format(unmatched_patterns))
1936
1937        if file_list:  # not empty
1938            return file_list
1939        raise ValueError("The list of path names matching the patterns is empty.")
1940
1941    def is_shuffled(self):
1942        return self.shuffle_flag > 0
1943
1944    def is_sharded(self):
1945        if self.num_shards is not None:
1946            return self.num_shards > 1
1947        return False
1948
1949
1950class MappableDataset(SourceDataset):
1951    """
1952    Abstract class to represent a source dataset which supports use of samplers.
1953    """
1954
1955    def parse(self, children=None):
1956        raise NotImplementedError("Dataset has to implement parse method.")
1957
1958    def __init__(self, num_parallel_workers=None, sampler=None, num_samples=None, shuffle=None, num_shards=None,
1959                 shard_id=None, cache=None):
1960        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
1961                         num_shards=num_shards, shard_id=shard_id, cache=cache)
1962        self.shuffle_flag = replace_none(shuffle, True)
1963        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
1964
1965    def add_sampler(self, new_sampler):
1966        """
1967        Add a sampler for the current dataset.
1968
1969        Args:
1970            new_sampler (Sampler): The sampler to be added as the parent sampler for current dataset.
1971
1972        Examples:
1973            >>> # dataset is an instance object of Dataset
1974            >>> # use a DistributedSampler instead
1975            >>> new_sampler = ds.DistributedSampler(10, 2)
1976            >>> dataset.add_sampler(new_sampler)
1977        """
1978        # note: By adding a sampler, the sampled IDs will flow to new_sampler
1979        # after first passing through the current samplers attached to this dataset.
1980        self.dataset_size = None
1981        new_sampler.add_child(self.sampler)
1982        self.sampler = new_sampler
1983
1984    def use_sampler(self, new_sampler):
1985        """
1986        Make the current dataset use the new_sampler provided by another API.
1987
1988        Args:
1989            new_sampler (Sampler): The sampler to use for the current dataset.
1990
1991        Examples:
1992            >>> # dataset is an instance object of Dataset
1993            >>> # use a DistributedSampler instead
1994            >>> new_sampler = ds.DistributedSampler(10, 2)
1995            >>> dataset.use_sampler(new_sampler)
1996        """
1997        if new_sampler is None:
1998            raise TypeError("Input sampler can not be None.")
1999        if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
2000            raise TypeError("Input sampler is not an instance of a sampler.")
2001        self.dataset_size = None
2002
2003        self.sampler = self.sampler.child_sampler
2004        self.add_sampler(new_sampler)
2005
2006    def is_shuffled(self):
2007        return self.sampler.is_shuffled()
2008
2009    def is_sharded(self):
2010        return self.sampler.is_sharded()
2011
2012    @check_split
2013    def split(self, sizes, randomize=True):
2014        """
2015        Split the dataset into smaller, non-overlapping datasets.
2016
2017        Args:
2018            sizes (Union[list[int], list[float]]): If a list of integers [s1, s2, …, sn] is
2019                provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
2020                respectively. If the sum of all sizes does not equal the original dataset size, an
2021                error will occur.
2022                If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1
2023                and must sum to 1, otherwise an error will occur. The dataset will be split into n
2024                Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the
2025                original dataset.
2026                If after rounding:
2027
2028                - Any size equals 0, an error will occur.
2029                - The sum of split sizes < K, the difference will be added to the first split.
2030                - The sum of split sizes > K, the difference will be removed from the first large
2031                  enough split such that it will have at least 1 row after removing the difference.
2032
2033            randomize (bool, optional): Determines whether or not to split the data randomly (default=True).
2034                If True, the data will be randomly split. Otherwise, each split will be created with
2035                consecutive rows from the dataset.
2036
2037        Note:
2038            1. There is an optimized split function, which will be called automatically when the dataset
2039               that calls this function is a MappableDataset.
2040            2. Dataset should not be sharded if split is going to be called. Instead, create a
2041               DistributedSampler and specify a split to shard after splitting. If the dataset is
2042               sharded after a split, it is strongly recommended to set the same seed in each instance
2043               of execution, otherwise each shard may not be part of the same split (see Examples).
2044            3. It is strongly recommended not to shuffle the dataset; use randomize=True instead.
2045               Shuffling the dataset may not be deterministic, which means the data in each split
2046               will be different in each epoch. Furthermore, if sharding occurs after split, each
2047               shard may not be part of the same split.
2048
2049        Raises:
2050            RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
2051            RuntimeError: If `sizes` is list of integers and sum of all elements in sizes does not
2052                equal the dataset size.
2053            RuntimeError: If `sizes` is list of float and there is a split with size 0 after calculations.
2054            RuntimeError: If the dataset is sharded prior to calling split.
2055            ValueError: If `sizes` is list of float and not all floats are between 0 and 1, or if the
2056                floats don't sum to 1.
2057
2058        Returns:
2059            tuple(Dataset), a tuple of datasets that have been split.
2060
2061        Examples:
2062            >>> # Since many datasets have shuffle on by default, set shuffle to False if split will be called!
2063            >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, shuffle=False)
2064            >>>
2065            >>> # Set the seed, and tell split to use this seed when randomizing.
2066            >>> # This is needed because sharding will be done later
2067            >>> ds.config.set_seed(58)
2068            >>> train_dataset, test_dataset = dataset.split([0.9, 0.1])
2069            >>>
2070            >>> # To shard the train dataset, use a DistributedSampler
2071            >>> train_sampler = ds.DistributedSampler(10, 2)
2072            >>> train_dataset.use_sampler(train_sampler)
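            >>>
            >>> # A sketch with absolute sizes: build a small 10-row dataset and take
            >>> # 8 consecutive rows for training and the remaining 2 for testing.
            >>> data = ds.NumpySlicesDataset(list(range(10)), shuffle=False)
            >>> train_dataset2, test_dataset2 = data.split([8, 2], randomize=False)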
2073        """
2074        if self.is_shuffled():
2075            logger.warning("Dataset is shuffled before split.")
2076
2077        if self.is_sharded():
2078            raise RuntimeError("Dataset should not be sharded before split.")
2079
2080        absolute_sizes = self._get_absolute_split_sizes(sizes)
2081        splits = []
2082        current_split_start_index = 0
2083        for size in absolute_sizes:
2084            ds = copy.deepcopy(self)
2085            ds.dataset_size = None
2086            if randomize:
2087                # want to shuffle the same way every epoch before split, we are assuming
2088                # that the user will call set_seed
2089                random_sampler = samplers.RandomSampler()
2090                random_sampler.reshuffle_each_epoch = False
2091                ds.add_sampler(random_sampler)
2092
2093            subset_sampler = samplers.SequentialSampler(current_split_start_index, size)
2094            ds.add_sampler(subset_sampler)
2095
2096            # add sequential sampler, so that if user calls use_sampler, we will
2097            # get rid of the sequential sampler instead of something we need
2098            ds.add_sampler(samplers.SequentialSampler())
2099
2100            splits.append(ds)
2101
2102            current_split_start_index += size
2103
2104        return tuple(splits)
2105
2106
2107class BucketBatchByLengthDataset(Dataset):
2108    """
2109    The result of applying BucketBatchByLength operator to the input dataset.
2110    """
2111
2112    def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes, element_length_function,
2113                 pad_info, pad_to_bucket_boundary, drop_remainder):
2114        super().__init__(children=input_dataset)
2115
2116        self.column_names = to_list(column_names)
2117        self.bucket_boundaries = replace_none(bucket_boundaries, [])
2118        self.bucket_batch_sizes = replace_none(bucket_batch_sizes, [])
2119        self.element_length_function = element_length_function
2120        self.pad_info = replace_none(pad_info, {})
2121        self.pad_to_bucket_boundary = replace_none(pad_to_bucket_boundary, False)
2122        self.drop_remainder = replace_none(drop_remainder, False)
2123
2124    def parse(self, children=None):
2125        return cde.BucketBatchByLengthNode(children[0], self.column_names, self.bucket_boundaries,
2126                                           self.bucket_batch_sizes, self.element_length_function, self.pad_info,
2127                                           self.pad_to_bucket_boundary, self.drop_remainder)
2128
2129
2130class BatchDataset(Dataset):
2131    """
2132    The result of applying Batch operator to the input dataset.
2133
2134    Args:
2135        input_dataset (Dataset): Input Dataset to be batched.
2136        batch_size (Union[int, function]): The number of rows each batch is created with. An
2137            int or callable which takes exactly 1 parameter, BatchInfo.
2138        drop_remainder (bool, optional): Determines whether or not to drop the last
2139            possibly incomplete batch (default=False). If True, and if there are less
2140            than batch_size rows available to make the last batch, then those rows will
2141            be dropped and not propagated to the child node.
2142        num_parallel_workers (int, optional): Number of workers to process the dataset in parallel (default=None).
2143        per_batch_map (callable, optional): Per batch map callable. A callable which takes
2144            (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch of
2145            Tensors on a given column. The number of lists should match with number of entries in input_columns. The
2146            last parameter of the callable must always be a BatchInfo object.
2147        input_columns (Union[str, list[str]], optional): List of names of the input columns. The size of the list must
2148            match with signature of per_batch_map callable.
2149        output_columns (Union[str, list[str]], optional): List of names assigned to the columns outputted by
2150            the last operation. This parameter is mandatory if len(input_columns) !=
2151            len(output_columns). The size of this list must match the number of output
2152            columns of the last operation. (default=None, output columns will have the same
2153            name as the input columns, i.e., the columns will be replaced).
2154        column_order (Union[str, list[str]], optional): Specifies the list of all the columns you need in the whole
2155                dataset. The parameter is required when len(input_columns) != len(output_columns). Caution: the list here
2156                is not just the columns specified in parameter input_columns and output_columns.
2157        pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)}
2158            will pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0.
2159        max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy
2160            data between processes. This is only used if python_multiprocessing is set to True (default 16 MB).
2161
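    Examples:
        >>> # A minimal sketch: double every value of "col1" within each batch via per_batch_map
        >>> # (assumes dataset is an instance object of Dataset with a numeric column "col1").
        >>> def double_rows(col1, batch_info):
        ...     return ([2 * row for row in col1],)
        >>> dataset = dataset.batch(batch_size=2, per_batch_map=double_rows, input_columns=["col1"])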
2162    """
2163
2164    def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
2165                 input_columns=None, output_columns=None, column_order=None, pad_info=None,
2166                 python_multiprocessing=False, max_rowsize=16):
2167        super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers)
2168
2169        if BatchDataset._is_ancestor_of_repeat(input_dataset):
2170            logger.warning("Repeat is located before batch, data from two epochs can be batched together.")
2171
2172        BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
2173
2174        # if batch_size is callable, set batch_size to 1 and batch_size_func to that callable function
2175        self.batch_size = batch_size if not callable(batch_size) else 1
2176        self.batch_size_func = None if not callable(batch_size) else batch_size
2177
2178        self.drop_remainder = replace_none(drop_remainder, False)
2179
2180        self.per_batch_map = per_batch_map
2181
2182        self.input_columns = to_list(input_columns)
2183        self.output_columns = to_list(output_columns)
2184        self.column_order = to_list(column_order)
2185
2186        self.pad = bool(pad_info is not None)
2187        self.pad_info = replace_none(pad_info, dict())
2188
2189        self.python_multiprocessing = python_multiprocessing
2190        self.process_pool = None
2191        self.hook = None
2192        self.pids = []
2193        self.eot = None
2194        self.watch_dog = None
2195        self.max_rowsize = max_rowsize
2196
2197    def parse(self, children=None):
2198        return cde.BatchNode(children[0], self.batch_size, self.drop_remainder, self.pad, self.input_columns,
2199                             self.output_columns, self.column_order, self.batch_size_func, self.per_batch_map,
2200                             self.pad_info)
2201
2202    @staticmethod
2203    def _is_ancestor_of_repeat(dataset):
2204        """
2205        Utility function to find the case where repeat is used before batch.
2206
2207        Args:
2208             dataset (Dataset): Dataset to be checked.
2209
2210        Returns:
2211            bool, whether repeat is used before batch.
2212        """
2213        if isinstance(dataset, RepeatDataset):
2214            return True
2215        flag = False
2216        for input_dataset in dataset.children:
2217            flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
2218        return flag
2219
2220    @staticmethod
2221    def _update_batch_size_for_syncwait(dataset, batch_size):
2222        """
2223        Utility function to notify batch size to sync_wait.
2224
2225        Args:
2226             dataset (Dataset): Dataset to be checked.
2227             batch_size (int): batch size to notify.
2228        """
2229        if isinstance(dataset, SyncWaitDataset):
2230            dataset.update_sync_batch_size(batch_size)
2231        for input_dataset in dataset.children:
2232            BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
2233
2234    def __deepcopy__(self, memodict):
2235        return self.__safe_deepcopy__(memodict, exclude=("per_batch_map", "batch_size_func", "__transfer_dataset__"))
2236
2237    # Iterator bootstrap will be called on iterator construction.
2238    # A deep copy of Dataset object is created prior of iterator_bootstrap.
2239    # This method will create per iterator process pool and bind pyfunc execution to the pool.
2240    def iterator_bootstrap(self):
2241        """
2242        Per iterator bootstrap callback.
2243        """
2244        if self.python_multiprocessing:
2245            if self.per_batch_map is None:
2246                logger.warning("per_batch_map is None so python_multiprocessing does not work.")
2247                return
2248            arg_q_list = []
2249            res_q_list = []
2250
2251            # If user didn't specify num_parallel_workers, set it to default
2252            if self.num_parallel_workers is not None:
2253                num_parallel = self.num_parallel_workers
2254            else:
2255                num_parallel = get_num_parallel_workers()
2256
2257            if get_enable_shared_mem():
2258                _check_shm_usage(num_parallel, 1, self.max_rowsize * self.batch_size, 2)
2259                for _ in range(num_parallel):
2260                    arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size))
2261                    res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size))
2262
2263            # Construct pool with the callable list
2264            # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses
2265            self.process_pool = multiprocessing.Pool(processes=num_parallel,
2266                                                     initializer=_pyfunc_worker_init,
2267                                                     initargs=([self.per_batch_map], arg_q_list, res_q_list))
2268
2269            idx = 0
2270            global _OP_NAME, _OP_PROCESS, _LOCK
2271            op_id = _OP_NAME[str(self)]
2272            process_id = {op_id: [self.num_parallel_workers, set()]}
2273            # obtain process id from multiprocessing.pool
2274            for pool in self.process_pool._pool:  # pylint: disable=W0212
2275                process_id[op_id][1].add(pool.pid)
2276                self.pids.append(pool.pid)
2277            with _LOCK:
2278                _OP_PROCESS.update(process_id)
2279
2280            # Wrap per_batch_map into _PythonCallable
2281            self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool, arg_q_list, res_q_list)
2282            self.hook = _ExceptHookHandler()
2283            atexit.register(_mp_pool_exit_preprocess)
2284            # For Python 3.8 and above, close the multiprocessing pool in atexit to avoid an unclean pool teardown.
2285            if sys.version_info >= (3, 8):
2286                atexit.register(self.process_pool.close)
2287            if platform.system().lower() != 'windows':
2288                self.eot = threading.Event()
2289                self.watch_dog = threading.Thread(target=_watch_dog, args=(self.eot, self.pids))
2290                self.watch_dog.daemon = True
2291                self.watch_dog.start()
2292        else:
2293            if self.per_batch_map is not None:
2294                self.per_batch_map = FuncWrapper(self.per_batch_map)
2295
2296    def _abort_watchdog(self):
2297        if not self.eot.is_set():
2298            self.eot.set()
2299
2300    def __del__(self):
2301        if hasattr(self, 'process_pool') and self.process_pool is not None:
2302            self.process_pool.close()
2303        if hasattr(self, 'watch_dog') and self.watch_dog is not None and hasattr(self, 'eot') and self.eot is not None:
2304            self._abort_watchdog()
2305
2306
2307class BatchInfo(cde.CBatchInfo):
2308    """
2309    The information object associated with the current batch of tensors.
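
    Examples:
        >>> # A sketch of a callable batch size driven by BatchInfo (assumes dataset is an
        >>> # instance object of Dataset): start at 2 rows per batch and add one each epoch.
        >>> def adjust(batch_info):
        ...     return 2 + batch_info.get_epoch_num()
        >>> dataset = dataset.batch(batch_size=adjust)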
2310    """
2311
2312    def get_batch_num(self):
2313        """
2314        Return the batch number of the current batch.
2315        """
2316        return
2317
2318    def get_epoch_num(self):
2319        """
2320        Return the epoch number of the current batch.
2321        """
2322        return
2323
2324
2325class BlockReleasePair:
2326    """
2327    The blocking condition class used by SyncWaitDataset.
2328
2329    Args:
2330        init_release_rows (int): Number of rows to allow through the pipeline.
2331        callback (function): The callback function that will be called when release is called (default=None).
2332    """
2333
2334    def __init__(self, init_release_rows, callback=None):
2335        if isinstance(init_release_rows, int) and init_release_rows <= 0:
2336            raise ValueError("init_release_rows needs to be greater than 0.")
2337        self.row_count = -init_release_rows
2338        self.cv = threading.Condition()
2339        self.callback = callback
2340        self.default_rows = init_release_rows
2341        self.disable = False
2342
2343    def __deepcopy__(self, memodict):
2344        return self
2345
2346    def reset(self):
2347        with self.cv:
2348            self.row_count = -self.default_rows
2349            self.cv.notify_all()
2350
2351    def update_batched_size(self, batch_size):
2352        # sanity check
2353        if isinstance(batch_size, int) and batch_size <= 0:
2354            raise ValueError("batch_size needs to be greater than 0.")
2355
2356        # should only be used before the pipeline is created
2357        self.row_count *= batch_size
2358        self.default_rows *= batch_size
2359
2360    def block_func(self):
2361        """
2362        Function for handling the blocking condition.
2363
2364        Returns:
2365            bool, True.
2366        """
2367        with self.cv:
2368            # if disable is true, the predicate always evaluates to true
2369            not_time_out = self.cv.wait_for(lambda: (self.row_count < 0 or self.disable),
2370                                            timeout=get_callback_timeout())
2371            # not_time_out will be False if a timeout occurs
2372            if not not_time_out:
2373                logger.warning("Timeout happened in sync_wait, maybe dataset.sync_update(condition=...) "
2374                               "is not added after dataset.create_dict_iterator(...), now disabling lock.")
2375                self.disable = True
2376            self.row_count += 1
2377        return True
2378
2379    def release_func(self, pass_rows=None, data=None):
2380        with self.cv:
2381            if pass_rows is None:
2382                pass_rows = self.default_rows
2383            self.row_count -= pass_rows
2384            if self.callback is not None:
2385                self.callback(data)
2386            self.cv.notify_all()
2387
2388    def disable_lock(self):
2389        with self.cv:
2390            self.disable = True
2391            self.cv.notify_all()
2392
2393
2394class SyncWaitDataset(Dataset):
2395    """
2396    The result of adding a blocking condition to the input Dataset.
2397
2398    Args:
2399        input_dataset (Dataset): Input dataset to apply flow control.
2400        condition_name (str): Condition name that is used to toggle sending next row.
2401        num_batch (int): Number of batches without blocking at the start of each epoch.
2402        callback (function): Callback function that will be invoked when sync_update is called (default=None).
2403
2404    Raises:
2405        RuntimeError: If condition name already exists.
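
    Examples:
        >>> # Illustrative sketch only; assumes `dataset` already exists and that this node is built
        >>> # through the Dataset.sync_wait / Dataset.sync_update entry points.
        >>> dataset = dataset.sync_wait(condition_name="policy", num_batch=1)
        >>> dataset = dataset.batch(batch_size=2)
        >>> for item in dataset.create_dict_iterator(num_epochs=1):
        ...     # consume one batch, then release the next num_batch batches through the pipeline
        ...     dataset.sync_update(condition_name="policy")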
2406    """
2407
2408    def __init__(self, input_dataset, condition_name, num_batch, callback=None):
2409        super().__init__(children=input_dataset)
2410
2411        # set to the default value, waiting for the batch to update it
2412        self._condition_name = condition_name
2413        if isinstance(num_batch, int) and num_batch <= 0:
            raise ValueError("num_batch needs to be greater than 0.")
2415
2416        self._pair = BlockReleasePair(num_batch, callback)
2417        if self._condition_name in self.children[0].get_sync_notifiers():
2418            raise RuntimeError("Condition name is already in use.")
2419        logger.info("Please remember to add dataset.sync_update(condition=%s), otherwise hanging will result. "
2420                    "If dataset.sync_update(condition=%s) has already been added, you can ignore the info.",
2421                    condition_name, condition_name)
2422
2423    def parse(self, children=None):
2424        return cde.SyncWaitNode(children[0], self._condition_name, self._pair.block_func)
2425
2426    def get_sync_notifiers(self):
2427        return {**self.children[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
2428
2429    def is_sync(self):
2430        return True
2431
2432    def update_sync_batch_size(self, batch_size):
2433        if isinstance(batch_size, int) and batch_size <= 0:
            raise ValueError("batch_size needs to be greater than 0.")
2435        self._pair.update_batched_size(batch_size)
2436
2437    def disable_sync(self):
2438        logger.info("Disabling Sync")
2439        self._pair.disable_lock()
2440
2441    @staticmethod
2442    def _is_ancestor_of_batch(dataset):
2443        """
2444        Utility function to find the case where sync_wait is used before batch.
2445
2446        Args:
2447             dataset (Dataset): Dataset to be checked.
2448
2449        Returns:
2450            bool, whether sync_wait is used before batch.
2451        """
2452        if isinstance(dataset, BatchDataset):
2453            return True
2454        flag = False
2455        for input_dataset in dataset.children:
2456            flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
2457        return flag
2458
2459    def iterator_bootstrap(self):
2460        self._pair.reset()
2461
2462
2463class ShuffleDataset(Dataset):
2464    """
2465    The result of applying Shuffle operator to the input Dataset.
2466
2467    Args:
2468        input_dataset (Dataset): Input Dataset to be shuffled.
2469        buffer_size (int): Size of the buffer.
2470
2471    Raises:
        RuntimeError: If sync operators exist before the shuffle operation.
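
    Examples:
        >>> # Illustrative sketch only; assumes `dataset` already exists and that this node is built
        >>> # through the Dataset.shuffle entry point with the given buffer size.
        >>> dataset = dataset.shuffle(buffer_size=4)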
2473    """
2474
2475    def __init__(self, input_dataset, buffer_size):
2476        super().__init__(children=input_dataset)
2477        self.buffer_size = buffer_size
2478        self.reshuffle_each_epoch = True
2479
2480        if self.is_sync():
            raise RuntimeError("Shuffle is not allowed after sync operators.")
2482
2483    def parse(self, children=None):
2484        return cde.ShuffleNode(children[0], self.buffer_size, self.reshuffle_each_epoch)
2485
2486    def is_shuffled(self):
2487        return True
2488
2489
2490# This wait function is for cleaning zombie subprocesses
2491def wait_pid():
2492    try:
2493        while True:
2494            child_pid, _ = os.waitpid(-1, os.WNOHANG)
2495            if child_pid == 0:
2496                break
2497    except OSError:
        # waitpid may fail for some reason, so we ignore the error
2499        pass
2500
2501
# Dataset needs the _watch_dog thread to monitor forked multiprocessing subprocesses,
# and the thread target can't be a member function, otherwise Python won't collect and release resources.
2504def _watch_dog(eot, pids):
2505    """
2506    This thread is for monitoring subprocesses forked by GeneratorDataset/MapDataset/BatchDataset
2507    """
2508    while not eot.is_set():
2509        subprocess_exit_num = 0
        # Monitor and count how many subprocesses have already exited
2511        for pid in pids:
2512            try:
2513                p = psutil.Process(pid)
2514                if p.status() == psutil.STATUS_ZOMBIE:
2515                    subprocess_exit_num += 1
2516            except psutil.NoSuchProcess:
2517                subprocess_exit_num += 1
        # If an exited subprocess is found, wait up to 30s and do some waitpid operations
2519        if subprocess_exit_num > 0:
2520            start = time.time()
2521            while time.time() - start < 30:
                # We need to distinguish the scenario where get_dataset_size or training finished normally
                # from a hang scenario. If get_dataset_size or training finished normally, _stop_subprocess can be
                # executed and self.need_abort can be set to True. If the main process hangs in get(),
                # self.need_abort will never be set to True, so we wait for 30s and then kill the main process.
2526                if eot.is_set():
2527                    return
                # Sometimes a subprocess may become a zombie, so within the 30s we keep doing useful work (waitpid).
2529                wait_pid()
            # multiprocessing.Queue may hang in .get() forever when the put() process was killed.
            # We have to exit the main process, otherwise the main process will hang.
2532            logger.error("The subprocess of dataset may exit unexpected or be killed, "
2533                         "main process will exit.")
2534            os.kill(os.getpid(), signal.SIGTERM)
2535
2536
2537# Pyfunc collection for multiprocess pyfunc
# These global variables will only be used within subprocesses
2539_GLOBAL_PYFUNC_LIST = []
2540_ARGS_QUEUE = []
2541_RET_QUEUE = []
2542_OP_NAME = dict()
2543_OP_PROCESS = dict()
2544_LOCK = threading.Lock()
2545
2546
2547# Pyfunc worker init function
# The Python multiprocessing library forbids sending lambda functions through a pipe.
# This init function allows us to add all Python functions to a global collection and then fork afterwards.
2550def _pyfunc_worker_init(pyfunc_list, args_queue, ret_queue):
2551    global _GLOBAL_PYFUNC_LIST
2552    global _ARGS_QUEUE
2553    global _RET_QUEUE
2554    _GLOBAL_PYFUNC_LIST = pyfunc_list
2555    _ARGS_QUEUE = args_queue
2556    _RET_QUEUE = ret_queue
2557
2558
2559# Pyfunc worker execution function
# All exceptions will be raised to the main process
2561def _pyfunc_worker_exec(index, qid, *args):
2562    """
    Internal function for calling a certain pyfunc in a Python subprocess.
2564    """
    # Some threads in multiprocessing.pool can't process the SIGINT signal
    # and will hang, so Ctrl+C is passed to the parent process to handle.
2567    signal.signal(signal.SIGINT, signal.SIG_IGN)
2568
2569    if qid != -1:
2570        # Pass arguments through the Queue instead of directly to remote process
2571        args = _ARGS_QUEUE[qid].get()
2572        try:
2573            r = _GLOBAL_PYFUNC_LIST[index](*args)
2574        except Exception:
2575            return ExceptionHandler(where="in map(or batch) worker and execute python function")
2576        if isinstance(r, tuple):
2577            _RET_QUEUE[qid].put(r)
2578        else:
2579            _RET_QUEUE[qid].put((r,))
2580        return [qid]
2581    # not using shared memory for passing arguments, call function directly
2582    result = None
2583    try:
2584        result = _GLOBAL_PYFUNC_LIST[index](*args)
2585    except Exception:
2586        result = ExceptionHandler(where="in map(or batch) worker and execute python function")
2587    return result
2588
2589
2590# PythonCallable wrapper for multiprocess pyfunc
2591class _PythonCallable:
2592    """
2593    Internal Python function wrapper for multiprocessing pyfunc.
2594    """
2595
2596    def __init__(self, py_callable, idx, pool=None, arg_q=None, res_q=None):
2597        # Original Python callable from user.
2598        self.py_callable = py_callable
2599        # Process pool created for current iterator.
2600        self.pool = pool
2601        # Python callable index for subprocess _GLOBAL_PYFUNC_LIST
2602        self.idx = idx
2603
2604        if pool is not None:
2605            self.queuemap = {}
2606            self.arg_q = arg_q
2607            self.res_q = res_q
2608            self.next_queue = 0
2609
2610    def __call__(self, *args):
2611        if self._pool_is_running() and check_iterator_cleanup() is False:
            # arg_q will be an empty list if we are not using shared memory
            # if using the multiprocessing shared queue instead of multiprocessing arg passing
2614            if self.arg_q != []:
2615                tid = threading.get_ident()
2616                # Need to register each thread to use a different queue to send data to pool
                if tid not in self.queuemap:
2618                    qid = self.next_queue
2619                    self.next_queue = self.next_queue + 1
2620                    self.queuemap[tid] = qid
2621                else:
2622                    qid = self.queuemap[tid]
2623                self.arg_q[qid].put(args)
2624
2625                # This call will send the tensors along with Python callable index to the process pool.
2626                # Block, yield GIL. Current thread will reacquire GIL once result is returned.
2627                if self._pool_is_running() and check_iterator_cleanup() is False:
2628                    result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, qid, []])
2629                else:
2630                    return self.py_callable(*args)
2631            else:
2632                result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, -1, *args])
2633
2634            # todo this check might be wrong
2635            while check_iterator_cleanup() is False:
2636                try:
2637                    if self.arg_q != []:
2638                        r = result.get(30)
2639                        if isinstance(r, ExceptionHandler):
2640                            r.reraise()
2641                        if r[0] != qid:
2642                            raise Exception("In PyCallable, got results from wrong thread")
2643                        r = self.res_q[qid].get()
2644                        return r
2645                    r = result.get(30)
2646                    if isinstance(r, ExceptionHandler):
2647                        r.reraise()
2648                    return r
2649                except multiprocessing.TimeoutError:
2650                    continue
2651                except KeyboardInterrupt:
2652                    _set_iterator_cleanup()
2653                    self.pool.close()
2654                    self.pool.join()
2655                    raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt.")
2656            return (None,)
2657        # Invoke original Python callable in master process in case the pool is gone.
2658        return self.py_callable(*args)
2659
2660    def to_json(self):
2661        return self.py_callable.to_json()
2662
2663    def _pool_is_running(self):
2664        # note here: the RUN state of python3.7 and python3.8 is different:
2665        # python3.7: RUN = 0
2666        # python3.8: RUN = "RUN"
2667        # so we use self.pool._state == RUN instead and we can't use _state == 0 any more.
2668        if self.pool is not None and self.pool._state == RUN:  # pylint: disable=W0212
2669            return True
2670        return False
2671
2672
2673def _mp_pool_exit_preprocess():
2674    if check_iterator_cleanup() is False:
        # Set the iterator_cleanup flag to True before exiting, and wait 3s for all apply_async calls
        # submitted to the multiprocessing pool to finish, to prevent multiprocessing from hanging on exit
2677        _set_iterator_cleanup()
2678        time.sleep(3)
2679
2680
2681class _ExceptHookHandler:
2682    def __init__(self):
2683        sys.excepthook = self.__handler_exception
2684
2685    def __handler_exception(self, ex_type, value, tb):
2686        logger.error("Uncaught exception: ", exc_info=(ex_type, value, tb))
2687        _mp_pool_exit_preprocess()
2688
2689
2690class MapDataset(Dataset):
2691    """
2692    The result of applying the Map operator to the input Dataset.
2693
2694    Args:
2695        input_dataset (Dataset): Input Dataset to be mapped.
2696        operations (TensorOp): A function mapping a nested structure of tensors
            to another nested structure of tensors (default=None).
        input_columns (Union[str, list[str]]): List of names of the input columns
            (default=None, the operations will be applied to the first columns in the dataset).
2700            The size of the list should match the number of inputs of the first operator.
2701        output_columns (Union[str, list[str]], optional): List of names of the output columns.
2702            The size of the list should match the number of outputs of the last operator
2703            (default=None, output columns will be the input columns, i.e., the columns will
2704            be replaced).
2705        column_order (list[str], optional): Specifies the list of all the columns you need in the whole
            dataset. This parameter is required when len(input_columns) != len(output_columns). Caution: the list here
            is not just the columns specified in the parameters input_columns and output_columns.
2708        num_parallel_workers (int, optional): Number of workers to process the dataset
2709            in parallel (default=None).
        python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
            option could be beneficial if the Python operation is computationally heavy (default=False).
2712        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
2713            (default=None, which means no cache is used).
        callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (default=None).
        max_rowsize (int, optional): Maximum size of a row in MB that is used for shared memory allocation to copy
            data between processes. This is only used if python_multiprocessing is set to True (default=16).
2717
    Raises:
        ValueError: If len(input_columns) != len(output_columns) and column_order is not specified.
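
    Examples:
        >>> # Illustrative sketch only; assumes `dataset` with columns ["image", "label"] already exists
        >>> # and that this node is built through the Dataset.map entry point. Since one input column is
        >>> # mapped to two output columns, column_order must list every column of the resulting dataset.
        >>> dataset = dataset.map(operations=[lambda img: (img, img.shape)],
        ...                       input_columns=["image"],
        ...                       output_columns=["image", "shape"],
        ...                       column_order=["image", "shape", "label"])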
2720    """
2721
2722    def __init__(self, input_dataset, operations=None, input_columns=None, output_columns=None, column_order=None,
2723                 num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16):
2724        super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers, cache=cache)
2725        self.operations = to_list(operations)
2726        self.operations = py_transforms.Compose.reduce(self.operations)
2727        self.input_columns = to_list(input_columns)
2728        self.output_columns = to_list(output_columns)
2729        self.column_order = replace_none(column_order, [])
2730
        # If output_columns is not provided, then use input_columns
2732        self.output_columns = self.input_columns if not self.output_columns else self.output_columns
2733
2734        if self.input_columns and self.output_columns \
2735                and len(self.input_columns) != len(self.output_columns) \
2736                and not self.column_order:
2737            raise ValueError("When length of input_columns and output_columns are not equal,"
2738                             " column_order must be specified.")
2739
2740        self.python_multiprocessing = python_multiprocessing
2741        self.process_pool = None
2742        self.hook = None
2743        self.pids = []
2744        self.eot = None
2745        self.watch_dog = None
2746
2747        self.callbacks = to_list(callbacks)
2748        self.max_rowsize = max_rowsize
2749
2750    def parse(self, children=None):
2751        operations = []
2752        for op in self.operations:
2753            if op and getattr(op, 'parse', None):
2754                operations.append(op.parse())
2755            else:
2756                operations.append(op)
2757
2758        callbacks = [cb.create_runtime_obj() for cb in self.callbacks]
2759        return cde.MapNode(children[0], operations, self.input_columns, self.output_columns, self.column_order,
2760                           callbacks)
2761
2762    def __deepcopy__(self, memodict):
2763        return self.__safe_deepcopy__(memodict, exclude=("operations", "callbacks", "__transfer_dataset__"))
2764
2765    # Iterator bootstrap will be called on iterator construction.
    # A deep copy of the Dataset object is created prior to iterator_bootstrap.
    # This method will create a per-iterator process pool and bind pyfunc execution to the pool.
2768    def iterator_bootstrap(self):
2769        """
2770        Per iterator bootstrap callback.
2771        """
2772
2773        if self.python_multiprocessing:
2774            iter_specific_operations = []
2775            callable_list = []
2776            arg_q_list = []
2777            res_q_list = []
2778
2779            # If user didn't specify num_parallel_workers, set it to default
2780            if self.num_parallel_workers is not None:
2781                num_parallel = self.num_parallel_workers
2782            else:
2783                num_parallel = get_num_parallel_workers()
2784
2785            if get_enable_shared_mem():
2786                _check_shm_usage(num_parallel, 1, self.max_rowsize, 2)
2787                for _ in range(num_parallel):
2788                    arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize))
2789                    res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize))
2790
2791            # Pass #1, look for Python callables and build list
2792            for op in self.operations:
                # our C transforms are now callable and should not be run in the Python multiprocessing pool
2794                if callable(op) and str(op).find("c_transform") < 0:
2795                    callable_list.append(op)
2796
2797            if callable_list:
2798                # Construct pool with the callable list
                # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses
2800                self.process_pool = multiprocessing.Pool(processes=num_parallel,
2801                                                         initializer=_pyfunc_worker_init,
2802                                                         initargs=(callable_list, arg_q_list, res_q_list))
2803
2804                # Pass #2
2805                idx = 0
2806                global _OP_NAME, _OP_PROCESS, _LOCK
2807                op_id = _OP_NAME[str(self)]
2808                # obtain process id from multiprocessing.pool
2809                process_id = {op_id: [self.num_parallel_workers, set()]}
2810                for pool in self.process_pool._pool:  # pylint: disable=W0212
2811                    process_id[op_id][1].add(pool.pid)
2812                    self.pids.append(pool.pid)
2813                with _LOCK:
2814                    _OP_PROCESS.update(process_id)
2815                for op in self.operations:
                    # our C transforms are now callable and should not be run in the Python multiprocessing pool
2817                    if callable(op) and str(op).find("c_transform") < 0:
2818                        # Wrap Python callable into _PythonCallable
2819                        iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool,
2820                                                                        arg_q_list, res_q_list))
2821                        idx += 1
2822                    else:
2823                        # CPP ops remain the same
2824                        iter_specific_operations.append(op)
2825                self.operations = iter_specific_operations
2826                self.hook = _ExceptHookHandler()
2827                atexit.register(_mp_pool_exit_preprocess)
                # If the Python version is 3.8 or higher, we need to close the process pool in atexit to avoid an unclean pool teardown.
2829                if sys.version_info >= (3, 8):
2830                    atexit.register(self.process_pool.close)
2831                if platform.system().lower() != 'windows':
2832                    self.eot = threading.Event()
2833                    self.watch_dog = threading.Thread(target=_watch_dog, args=(self.eot, self.pids))
2834                    self.watch_dog.daemon = True
2835                    self.watch_dog.start()
2836
2837    def _abort_watchdog(self):
2838        if not self.eot.is_set():
2839            self.eot.set()
2840
2841    def __del__(self):
2842        if hasattr(self, 'process_pool') and self.process_pool is not None:
2843            self.process_pool.close()
2844            self.process_pool.join()
2845        if hasattr(self, 'watch_dog') and self.watch_dog is not None and hasattr(self, 'eot') and self.eot is not None:
2846            self._abort_watchdog()
2847
2848
2849class FilterDataset(Dataset):
2850    """
2851    The result of applying filter predicate to the input Dataset.
2852
2853    Args:
2854        input_dataset (Dataset): Input Dataset to be mapped.
        predicate (callable): Python callable which returns a boolean value. If False, filter out the element.
        input_columns (Union[str, list[str]], optional): List of names of the input columns
            (default=None, the predicate will be applied to all columns in the dataset).
2858        num_parallel_workers (int, optional): Number of workers to process the dataset
2859            in parallel (default=None).
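
    Examples:
        >>> # Illustrative sketch only; assumes `dataset` has a "label" column and that this node is
        >>> # built through the Dataset.filter entry point. Rows whose predicate returns False are dropped.
        >>> dataset = dataset.filter(predicate=lambda label: label < 5, input_columns=["label"])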
2860    """
2861
2862    def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
2863        super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers)
2864        self.predicate = lambda *args: bool(predicate(*args))
2865        self.input_columns = to_list(input_columns)
2866
2867    def parse(self, children=None):
2868        return cde.FilterNode(children[0], self.predicate, self.input_columns)
2869
2870
2871class RepeatDataset(Dataset):
2872    """
2873    The result of applying Repeat operator to the input Dataset.
2874
2875    Args:
2876        input_dataset (Dataset): Input Dataset to be repeated.
2877        count (int): Number of times the dataset will be repeated (default=-1, repeat indefinitely).
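
    Examples:
        >>> # Illustrative sketch only; assumes `dataset` already exists and that this node is built
        >>> # through the Dataset.repeat entry point; the data is passed through twice.
        >>> dataset = dataset.repeat(2)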
2878    """
2879
2880    def __init__(self, input_dataset, count):
2881        super().__init__(children=input_dataset)
2882        self.count = replace_none(count, -1)
2883
2884    def parse(self, children=None):
2885        return cde.RepeatNode(children[0], self.count)
2886
2887
2888class SkipDataset(Dataset):
2889    """
2890    The result of applying Skip operator to the input Dataset.
2891
2892    Args:
2893        input_dataset (Dataset): Input dataset to have elements skipped.
2894        count (int): Number of elements to be skipped in the dataset.
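
    Examples:
        >>> # Illustrative sketch only; assumes `dataset` already exists and that this node is built
        >>> # through the Dataset.skip entry point; the first 3 elements are skipped.
        >>> dataset = dataset.skip(3)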
2895    """
2896
2897    def __init__(self, input_dataset, count):
2898        super().__init__(input_dataset)
2899        self.count = count
2900
2901    def parse(self, children=None):
2902        return cde.SkipNode(children[0], self.count)
2903
2904
2905class TakeDataset(Dataset):
2906    """
2907    The result of applying Take operator to the input Dataset.
2908
2909    Args:
2910        input_dataset (Dataset): Input Dataset to have elements taken from.
2911        count (int): Number of elements to be taken from the dataset.
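
    Examples:
        >>> # Illustrative sketch only; assumes `dataset` already exists and that this node is built
        >>> # through the Dataset.take entry point; only the first 5 elements are kept.
        >>> dataset = dataset.take(5)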
2912    """
2913
2914    def __init__(self, input_dataset, count):
2915        super().__init__(children=input_dataset)
2916        self.count = count
2917
2918    def parse(self, children=None):
2919        return cde.TakeNode(children[0], self.count)
2920
2921
2922class ZipDataset(Dataset):
2923    """
2924    The result of applying Zip operator to the input Dataset.
2925
2926    Args:
2927        datasets (tuple): A tuple of datasets to be zipped together.
2928
2929    Raises:
2930        TypeError: If dataset is not an instance of Dataset.
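
    Examples:
        >>> # Illustrative sketch only; assumes dataset_1 and dataset_2 already exist, have no
        >>> # overlapping column names, and are combined through the ds.zip entry point.
        >>> dataset = ds.zip((dataset_1, dataset_2))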
2931    """
2932
2933    def __init__(self, datasets):
2934        super().__init__(children=datasets)
2935
2936    def parse(self, children=None):
2937        return cde.ZipNode(children)
2938
2939    def is_sync(self):
2940        return any([c.is_sync() for c in self.children])
2941
2942
2943class ConcatDataset(Dataset):
2944    """
2945    The result of applying concat dataset operator to the input Dataset.
2946
2947    Args:
2948        datasets (list): A list of datasets to be concatenated together.
2949
2950    Raises:
2951        TypeError: If dataset is not an instance of Dataset.
        ValueError: If there are no samples in one of the datasets.
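
    Examples:
        >>> # Illustrative sketch only; assumes dataset_1 and dataset_2 already exist and share the
        >>> # same column layout; the "+" operator and Dataset.concat both build this node.
        >>> dataset = dataset_1 + dataset_2
        >>> dataset = dataset_1.concat(dataset_2)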
2953    """
2954
2955    def __init__(self, datasets):
2956        super().__init__(children=datasets)
2957        for dataset in datasets:
2958            if not isinstance(dataset, Dataset):
2959                raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset))
2960        self.datasets = datasets
2961        self._sampler = samplers.SequentialSampler(num_samples=None)
2962
2963        self.children_sizes_ = [c.get_dataset_size() for c in self.children]
2964        child_index = 0
2965        for item in self.children_sizes_:
2966            if item == 0:
2967                raise ValueError("There are no samples in the dataset number %d. Please make sure there are "
2968                                 "valid samples in the dataset." % child_index)
2969            child_index += 1
2970
        # _children_flag_and_nums: A list of (int, int) pairs. The first element of the pair is a flag that indicates
        # whether the dataset is mappable. The second element of the pair is the length of the dataset.
2973        self._children_flag_and_nums = []
2974
        # _children_start_end_index_: A list of (int, int) pairs. The elements of the pair describe
        # the valid sampling position range of the dataset at the corresponding subscript.
2977        self._children_start_end_index_ = []
2978        for index, child in enumerate(self.children):
2979            tem_list = [-1, -1]
2980            self._children_start_end_index_.append(tem_list)
2981            dataset_len = self.children_sizes_[index]
2982            if isinstance(child, GeneratorDataset) and not hasattr(child.source, "__getitem__"):
2983                dataset_len = 0
2984                self.children_sizes_[index] = 0
2985
2986            if isinstance(child, MappableDataset):
2987                self._children_flag_and_nums.append((0, dataset_len))
2988            else:
2989                self._children_flag_and_nums.append((1, dataset_len))
2990
2991    def parse(self, children=None):
2992        return cde.ConcatNode(children, self._sampler, self._children_flag_and_nums, self._children_start_end_index_)
2993
2994    def use_sampler(self, sampler):
2995        """
        Set the DistributedSampler for the concat dataset.
2997
2998        Args:
2999            sampler (Sampler): The sampler to use for the current dataset.
3000                Currently supported: DistributedSampler.
3001
3002        Raises:
            TypeError: If the sampler is not an instance of DistributedSampler.
            ValueError: If the parameter shuffle of the sampler is True.
            ValueError: If the parameter num_samples of the sampler is not None.
            ValueError: If num_shards <= 0.
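
        Examples:
            >>> # Illustrative sketch only; assumes dataset_1 and dataset_2 already exist and that
            >>> # ds.DistributedSampler is the DistributedSampler from this package.
            >>> dataset = dataset_1 + dataset_2
            >>> dataset.use_sampler(ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False))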
3007        """
3008        if not isinstance(sampler, samplers.DistributedSampler):
3009            raise TypeError("The parameter %s of concat must be DistributedSampler!" % sampler)
3010
3011        if sampler.is_shuffled():
3012            raise ValueError("The parameter shuffle of DistributedSampler must be False!")
3013
3014        if sampler.num_shards <= 0:
3015            raise ValueError("The parameter num_shards of DistributedSampler must be positive int!")
3016
3017        if sampler.get_num_samples() is not None:
            raise ValueError("The parameter num_samples of DistributedSampler is not supported to be set!")
3019
3020        self.dataset_size = None
3021
3022        self._sampler = sampler
3023        cumulative_samples_nums = 0
3024        for index, child in enumerate(self.children):
3025            if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None:
                raise ValueError("The parameter num_samples of %s is not supported to be set!" % child)
3027
3028            if isinstance(child, BatchDataset):
3029                raise TypeError("The parameter %s of concat must not be BatchDataset!" % child)
3030
3031            # if child is mappable and the length is greater than 0
3032            if not self._children_flag_and_nums[index][0] and self._children_flag_and_nums[index][1]:
3033
3034                tem_value = cumulative_samples_nums + self._children_flag_and_nums[index][1]
3035
3036                if not self._children_flag_and_nums[index][1] >= sampler.num_shards:
3037                    if tem_value < sampler.num_shards:
3038                        self._children_start_end_index_[index][0] = cumulative_samples_nums
3039                        self._children_start_end_index_[index][1] = tem_value
3040                    else:
3041                        self._children_start_end_index_[index][0] = cumulative_samples_nums
3042                        self._children_start_end_index_[index][1] = tem_value % sampler.num_shards
3043
3044                tem_sampler = copy.deepcopy(sampler)
3045                tem_sampler.set_offset(cumulative_samples_nums)
3046                child.use_sampler(tem_sampler)
3047
3048            cumulative_samples_nums += self.children_sizes_[index]
3049            cumulative_samples_nums %= sampler.num_shards
3050
3051
3052class RenameDataset(Dataset):
3053    """
3054    The result of applying Rename operator to the input Dataset.
3055
3056    Args:
3057        input_dataset (Dataset): Input Dataset to be Renamed.
3058        input_columns (Union[str, list[str]]): List of names of the input columns.
3059        output_columns (Union[str, list[str]]): List of names of the output columns.
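
    Examples:
        >>> # Illustrative sketch only; assumes `dataset` has a column "image" and that this node is
        >>> # built through the Dataset.rename entry point.
        >>> dataset = dataset.rename(input_columns=["image"], output_columns=["data"])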
3060    """
3061
3062    def __init__(self, input_dataset, input_columns, output_columns):
3063        super().__init__(children=input_dataset)
3064        self.input_column_names = to_list(input_columns)
3065        self.output_column_names = to_list(output_columns)
3066
3067    def parse(self, children=None):
3068        return cde.RenameNode(children[0], self.input_column_names, self.output_column_names)
3069
3070
3071def to_list(items):
3072    if items is None:
3073        return []
3074    if isinstance(items, tuple):
3075        return list(items)
3076    if not isinstance(items, list):
3077        return [items]
3078    return items
3079
3080
3081class ProjectDataset(Dataset):
3082    """
3083    The result of applying Project operator to the input Dataset.
3084
3085    Args:
3086        input_dataset (Dataset): Input Dataset to be Projected.
3087        columns (Union[str, list[str]]): List of names of the columns to project.
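
    Examples:
        >>> # Illustrative sketch only; assumes `dataset` has columns "image" and "label" and that this
        >>> # node is built through the Dataset.project entry point; only the listed columns are kept,
        >>> # in the listed order.
        >>> dataset = dataset.project(columns=["image", "label"])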
3088    """
3089
3090    def __init__(self, input_dataset, columns):
3091        super().__init__(children=input_dataset)
3092        self.columns = to_list(columns)
3093
3094    def parse(self, children=None):
3095        return cde.ProjectNode(children[0], self.columns)
3096
3097
3098class _ToDevice:
3099    """
3100    Internal class to handle sending data to device.
3101    """
3102
3103    def __init__(self, dataset, num_epochs):
3104        ir_tree, self.api_tree = dataset.create_ir_tree()
3105
3106        self._runtime_context = cde.PythonRuntimeContext()
3107        self._runtime_context.Init()
3108        self._to_device = cde.ToDevice(num_epochs)
3109        self._to_device.Init(ir_tree)
3110        self._runtime_context.AssignConsumer(self._to_device)
3111
3112        ITERATORS_LIST.append(weakref.ref(self))
3113        _unset_iterator_cleanup()
3114
3115    def send(self):
3116        self._to_device.Send()
3117
3118    def stop_send(self):
3119        """
        Send a stop-send signal to the pipeline; it is used when the end-of-sequence flag is sent at the epoch end.
3121        """
3122        self._to_device.StopSend()
3123
3124    def continue_send(self):
3125        """
        Send a continue-send signal to the pipeline; it is used when the end-of-sequence flag is sent at the epoch end.
3127        """
3128        self._to_device.ContinueSend()
3129
3130    def get_data_info(self):
3131        """
3132        Get type and shape of current batch.
3133        """
3134        return self._to_device.GetDataInfo()
3135
3136    def release(self):
3137        """
3138        Manually terminate Device Queue instead of relying on out of scope destruction.
3139        """
3140        if hasattr(self, '_runtime_context') and self._runtime_context:
3141            if hasattr(self, '_to_device') and self._to_device:
3142                self._runtime_context.Terminate()
3143                del self._to_device
3144            del self._runtime_context
3145
3146    def __deepcopy__(self, memodict):
3147        return self
3148
3149
3150class TransferDataset(Dataset):
3151    """
    The result of applying the TDT operator to the input Dataset.
3153
3154    Args:
3155        input_dataset (Dataset): Input Dataset to be transferred.
3156        send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True).
3157        create_data_info_queue (bool, optional): Whether to create queue which stores
3158            types and shapes of data or not (default=False).
3159
3160    Raises:
3161        TypeError: If device_type is empty.
3162        ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'.
3163        RuntimeError: If dataset is unknown.
3164    """
3165
3166    def __init__(self, input_dataset, send_epoch_end=True, create_data_info_queue=False):
3167        super().__init__(children=input_dataset)
3168        self.queue_name = str(uuid.uuid1())
3169        self.device_type = context.get_context("device_target") if context else "CPU"
3170        self.device_id = context.get_context("device_id") if context else 0
3171
3172        self._send_epoch_end = replace_none(send_epoch_end, True)
3173        self._create_data_info_queue = create_data_info_queue
3174        self._to_device = None
3175
3176    def parse(self, children=None):
3177        total_batch = 0
3178        if hasattr(self.children[0], "__total_batch__"):
3179            total_batch = self.children[0].__total_batch__
3180        return cde.TransferNode(children[0], self.queue_name, self.device_type, self.device_id, self._send_epoch_end,
3181                                total_batch, self._create_data_info_queue)
3182
3183    def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
3184        raise RuntimeError("TransferDataset is not iterable.")
3185
3186    def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
3187        raise RuntimeError("TransferDataset is not iterable.")
3188
3189    def __iter__(self):
3190        raise RuntimeError("TransferDataset is not iterable.")
3191
3192    def output_shapes(self):
3193        raise RuntimeError("TransferDataset does not support obtaining output_shapes.")
3194
3195    def output_types(self):
3196        raise RuntimeError("TransferDataset does not support obtaining output_types.")
3197
3198    @check_to_device_send
3199    def send(self, num_epochs=-1):
3200        """
        Send data to the device.
3202        """
3203        if Dataset._noop_mode():
3204            return
3205        if self._to_device is not None:
3206            del self._to_device
3207        self._to_device = _ToDevice(self, num_epochs)
3208        self._to_device.send()
3209
3210    def stop_send(self):
3211        if self._to_device is not None:
3212            self._to_device.stop_send()
3213
3214    def continue_send(self):
3215        if self._to_device is not None:
3216            self._to_device.continue_send()
3217
3218    def get_data_info(self):
3219        """
        Get the type and shape of the current batch.
3221        """
3222        if self._to_device is not None:
3223            return self._to_device.get_data_info()
3224        raise RuntimeError("Calling get_data_info with bad state.")
3225
3226    def release(self):
3227        """
3228        Manually terminate Device Queue instead of relying on out of scope destruction.
3229        """
3230        if self._to_device is not None:
3231            self._to_device.release()
3232
3233
3234class RangeDataset(MappableDataset):
3235    """
    A source dataset that generates a sequence of numbers over the range specified by start, stop and step.
3237
3238    Args:
3239        start (int): Starting index.
3240        stop (int): Ending index.
3241        step (int): Step size in the range specified by start and stop.
3242    """
3243
3244    def __init__(self, start, stop, step):
3245        super().__init__()
3246        self.start = start
3247        self.stop = stop
3248        self.step = step
3249
3250    def parse(self, children=None):
3251        raise NotImplementedError("Dataset has to implement parse method.")
3252
3253    def is_shuffled(self):
3254        return False
3255
3256    def is_sharded(self):
3257        return False
3258
3259    def get_dataset_size(self):
3260        if self.dataset_size is None:
3261            self.dataset_size = math.ceil((self.stop - self.start) / self.step)
3262        return self.dataset_size
3263
3264
3265class ImageFolderDataset(MappableDataset):
3266    """
3267    A source dataset that reads images from a tree of directories.
3268    All images within one folder have the same label.
3269
3270    The generated dataset has two columns: :py:obj:`[image, label]`.
3271    The tensor of column :py:obj:`image` is of the uint8 type.
    The tensor of column :py:obj:`label` is a scalar of the uint32 type.
3273
3274    Args:
3275        dataset_dir (str): Path to the root directory that contains the dataset.
3276        num_samples (int, optional): The number of images to be included in the dataset
3277            (default=None, all images).
3278        num_parallel_workers (int, optional): Number of workers to read the data
3279            (default=None, set in the config).
3280        shuffle (bool, optional): Whether or not to perform shuffle on the dataset
3281            (default=None, expected order behavior shown in the table).
3282        sampler (Sampler, optional): Object used to choose samples from the
3283            dataset (default=None, expected order behavior shown in the table).
3284        extensions (list[str], optional): List of file extensions to be
3285            included in the dataset (default=None).
3286        class_indexing (dict, optional): A str-to-int mapping from folder name to index
3287            (default=None, the folder names will be sorted
3288            alphabetically and each class will be given a
3289            unique index starting from 0).
3290        decode (bool, optional): Decode the images after reading (default=False).
3291        num_shards (int, optional): Number of shards that the dataset will be divided
3292            into (default=None). When this argument is specified, `num_samples` reflects
3293            the maximum sample number of per shard.
3294        shard_id (int, optional): The shard ID within num_shards (default=None). This
3295            argument can only be specified when num_shards is also specified.
3296        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
3297            (default=None, which means no cache is used).
3298
3299    Raises:
3300        RuntimeError: If dataset_dir does not contain data files.
3301        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
3302        RuntimeError: If sampler and shuffle are specified at the same time.
3303        RuntimeError: If sampler and sharding are specified at the same time.
3304        RuntimeError: If num_shards is specified but shard_id is None.
3305        RuntimeError: If shard_id is specified but num_shards is None.
3306        RuntimeError: If class_indexing is not a dictionary.
3307        ValueError: If shard_id is invalid (< 0 or >= num_shards).
3308
3309    Note:
3310        - The shape of the image column is [image_size] if decode flag is False, or [H,W,C] otherwise.
3311        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
3312          The table below shows what input arguments are allowed and their expected behavior.
3313
3314    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
3315       :widths: 25 25 50
3316       :header-rows: 1
3317
3318       * - Parameter `sampler`
3319         - Parameter `shuffle`
3320         - Expected Order Behavior
3321       * - None
3322         - None
3323         - random order
3324       * - None
3325         - True
3326         - random order
3327       * - None
3328         - False
3329         - sequential order
3330       * - Sampler object
3331         - None
3332         - order defined by sampler
3333       * - Sampler object
3334         - True
3335         - not allowed
3336       * - Sampler object
3337         - False
3338         - not allowed
3339
3340    Examples:
3341        >>> image_folder_dataset_dir = "/path/to/image_folder_dataset_directory"
3342        >>>
3343        >>> # 1) Read all samples (image files) in image_folder_dataset_dir with 8 threads
3344        >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir,
3345        ...                                 num_parallel_workers=8)
3346        >>>
3347        >>> # 2) Read all samples (image files) from folder cat and folder dog with label 0 and 1
3348        >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir,
3349        ...                                 class_indexing={"cat":0, "dog":1})
3350        >>>
3351        >>> # 3) Read all samples (image files) in image_folder_dataset_dir with extensions .JPEG and .png (case sensitive)
3352        >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir,
3353        ...                                 extensions=[".JPEG", ".png"])
3354
3355    About ImageFolderDataset:
3356
3357    You can construct the following directory structure from your dataset files and read by MindSpore's API.
3358
3359    .. code-block::
3360
3361        .
3362        └── image_folder_dataset_directory
3363             ├── class1
3364             │    ├── 000000000001.jpg
3365             │    ├── 000000000002.jpg
3366             │    ├── ...
3367             ├── class2
3368             │    ├── 000000000001.jpg
3369             │    ├── 000000000002.jpg
3370             │    ├── ...
3371             ├── class3
3372             │    ├── 000000000001.jpg
3373             │    ├── 000000000002.jpg
3374             │    ├── ...
3375             ├── classN
3376             ├── ...
3377    """
3378
3379    @check_imagefolderdataset
3380    def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None,
3381                 extensions=None, class_indexing=None, decode=False, num_shards=None, shard_id=None, cache=None):
3382        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
3383                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
3384
3385        self.dataset_dir = dataset_dir
3386        self.extensions = replace_none(extensions, [])
3387        self.class_indexing = replace_none(class_indexing, {})
3388        self.decode = replace_none(decode, False)
3389
3390    def parse(self, children=None):
3391        return cde.ImageFolderNode(self.dataset_dir, self.decode, self.sampler, self.extensions, self.class_indexing)
3392
3393
3394class MnistDataset(MappableDataset):
3395    """
3396    A source dataset for reading and parsing the MNIST dataset.
3397
3398    The generated dataset has two columns :py:obj:`[image, label]`.
3399    The tensor of column :py:obj:`image` is of the uint8 type.
3400    The tensor of column :py:obj:`label` is a scalar of the uint32 type.
3401
3402    Args:
3403        dataset_dir (str): Path to the root directory that contains the dataset.
        usage (str, optional): Usage of this dataset, can be `train`, `test` or `all`. `train` will read from 60,000
3405            train samples, `test` will read from 10,000 test samples, `all` will read from all 70,000 samples.
3406            (default=None, will read all samples)
3407        num_samples (int, optional): The number of images to be included in the dataset
3408            (default=None, will read all images).
3409        num_parallel_workers (int, optional): Number of workers to read the data
3410            (default=None, will use value set in the config).
3411        shuffle (bool, optional): Whether or not to perform shuffle on the dataset
3412            (default=None, expected order behavior shown in the table).
3413        sampler (Sampler, optional): Object used to choose samples from the
3414            dataset (default=None, expected order behavior shown in the table).
3415        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
3416            When this argument is specified, `num_samples` reflects the maximum sample number of per shard.
3417        shard_id (int, optional): The shard ID within `num_shards` (default=None). This
3418            argument can only be specified when `num_shards` is also specified.
3419        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
3420            (default=None, which means no cache is used).
3421
3422    Raises:
3423        RuntimeError: If dataset_dir does not contain data files.
3424        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
3425        RuntimeError: If sampler and shuffle are specified at the same time.
3426        RuntimeError: If sampler and sharding are specified at the same time.
3427        RuntimeError: If num_shards is specified but shard_id is None.
3428        RuntimeError: If shard_id is specified but num_shards is None.
3429        ValueError: If shard_id is invalid (< 0 or >= num_shards).
3430
3431    Note:
3432        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
3433          The table below shows what input arguments are allowed and their expected behavior.
3434
3435    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
3436       :widths: 25 25 50
3437       :header-rows: 1
3438
3439       * - Parameter `sampler`
3440         - Parameter `shuffle`
3441         - Expected Order Behavior
3442       * - None
3443         - None
3444         - random order
3445       * - None
3446         - True
3447         - random order
3448       * - None
3449         - False
3450         - sequential order
3451       * - Sampler object
3452         - None
3453         - order defined by sampler
3454       * - Sampler object
3455         - True
3456         - not allowed
3457       * - Sampler object
3458         - False
3459         - not allowed
3460
3461    Examples:
3462        >>> mnist_dataset_dir = "/path/to/mnist_dataset_directory"
3463        >>>
3464        >>> # Read 3 samples from MNIST dataset
3465        >>> dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir, num_samples=3)
3466        >>>
3467        >>> # Note: In mnist_dataset dataset, each dictionary has keys "image" and "label"
3468
3469    About MNIST dataset:
3470
3471    The MNIST database of handwritten digits has a training set of 60,000 examples,
3472    and a test set of 10,000 examples. It is a subset of a larger set available from
3473    NIST. The digits have been size-normalized and centered in a fixed-size image.
3474
3475    Here is the original MNIST dataset structure.
3476    You can unzip the dataset files into this directory structure and read by MindSpore's API.
3477
3478    .. code-block::
3479
3480        .
3481        └── mnist_dataset_dir
3482             ├── t10k-images-idx3-ubyte
3483             ├── t10k-labels-idx1-ubyte
3484             ├── train-images-idx3-ubyte
3485             └── train-labels-idx1-ubyte
3486
3487    Citation:
3488
3489    .. code-block::
3490
3491        @article{lecun2010mnist,
3492        title        = {MNIST handwritten digit database},
3493        author       = {LeCun, Yann and Cortes, Corinna and Burges, CJ},
3494        journal      = {ATT Labs [Online]},
3495        volume       = {2},
3496        year         = {2010},
3497        howpublished = {http://yann.lecun.com/exdb/mnist}
3498        }
3499    """
3500
3501    @check_mnist_cifar_dataset
3502    def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None,
3503                 sampler=None, num_shards=None, shard_id=None, cache=None):
3504        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
3505                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
3506
3507        self.dataset_dir = dataset_dir
3508        self.usage = replace_none(usage, "all")
3509
3510    def parse(self, children=None):
3511        return cde.MnistNode(self.dataset_dir, self.usage, self.sampler)
3512
3513
3514class MindDataset(MappableDataset):
3515    """
3516    A source dataset for reading and parsing MindRecord dataset.
3517
3518    The columns of generated dataset depend on the source MindRecord files.
3519
3520    Args:
        dataset_file (Union[str, list[str]]): If dataset_file is a str, it represents
            the file name of one component of a mindrecord source; other files with an identical source
            in the same path will be found and loaded automatically. If dataset_file is a list,
            it represents a list of dataset files to be read directly.
3525        columns_list (list[str], optional): List of columns to be read (default=None).
3526        num_parallel_workers (int, optional): The number of readers (default=None).
3527        shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
3528            (default=None, performs global shuffle).
3529            If shuffle is False, no shuffling will be performed;
            If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL.
3531            Otherwise, there are three levels of shuffling:
3532
3533            - Shuffle.GLOBAL: Global shuffle of all rows of data in dataset.
3534
3535            - Shuffle.FILES: Shuffle the file sequence but keep the order of data within each file.
3536
3537            - Shuffle.INFILE: Keep the file sequence the same but shuffle the data within each file.
3538
3539        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
3540            When this argument is specified, 'num_samples' reflects the maximum sample number of per shard.
3541        shard_id (int, optional): The shard ID within num_shards (default=None). This
3542            argument can only be specified when num_shards is also specified.
3543        sampler (Sampler, optional): Object used to choose samples from the
3544            dataset (default=None, sampler is exclusive
            with shuffle and block_reader). Supported list: SubsetRandomSampler,
            PKSampler, RandomSampler, SequentialSampler, DistributedSampler.
        padded_sample (dict, optional): Samples will be appended to the dataset, where
            keys are the same as columns_list.
3549        num_padded (int, optional): Number of padding samples. Dataset size
3550            plus num_padded should be divisible by num_shards.
3551        num_samples (int, optional): The number of samples to be included in the dataset
3552            (default=None, all samples).
3553        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
3554            (default=None, which means no cache is used).
3555
3556    Raises:
3557        RuntimeError: If dataset_files are not valid or do not exist.
3558        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
3559        RuntimeError: If num_shards is specified but shard_id is None.
3560        RuntimeError: If shard_id is specified but num_shards is None.
3561        ValueError: If shard_id is invalid (< 0 or >= num_shards).
3562
3563    Note:
3564        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
3565          The table below shows what input arguments are allowed and their expected behavior.
3566
3567    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
3568       :widths: 25 25 50
3569       :header-rows: 1
3570
3571       * - Parameter `sampler`
3572         - Parameter `shuffle`
3573         - Expected Order Behavior
3574       * - None
3575         - None
3576         - random order
3577       * - None
3578         - True
3579         - random order
3580       * - None
3581         - False
3582         - sequential order
3583       * - Sampler object
3584         - None
3585         - order defined by sampler
3586       * - Sampler object
3587         - True
3588         - not allowed
3589       * - Sampler object
3590         - False
3591         - not allowed
3592
3593    Examples:
3594        >>> mind_dataset_dir = ["/path/to/mind_dataset_file"] # contains 1 or multiple MindRecord files
3595        >>> dataset = ds.MindDataset(dataset_file=mind_dataset_dir)
3596    """
3597
3598    def parse(self, children=None):
3599        return cde.MindDataNode(self.dataset_file, self.columns_list, self.sampler, self.new_padded_sample,
3600                                self.num_padded, shuffle_to_shuffle_mode(self.shuffle_option))
3601
3602    @check_minddataset
3603    def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None, shuffle=None, num_shards=None,
3604                 shard_id=None, sampler=None, padded_sample=None, num_padded=None, num_samples=None, cache=None):
3605        if shuffle is not None and not isinstance(shuffle, (bool, Shuffle)):
3606            raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like 'Shuffle.GLOBAL' or "
3607                            "'Shuffle.FILES' or 'Shuffle.INFILE'.")
3608        self.shuffle_option = shuffle
3609        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
3610                         shuffle=shuffle_to_bool(shuffle), num_shards=num_shards, shard_id=shard_id, cache=cache)
3611        if isinstance(dataset_file, list):
3612            self.load_dataset = False
3613        else:
3614            self.load_dataset = True
3615        self.dataset_file = dataset_file
3616        self.columns_list = replace_none(columns_list, [])
3617
3618        if shuffle is False:
3619            logger.warning("WARN: global shuffle is not used.")
3620
3621        if sampler is not None:
3622            if isinstance(sampler, (
3623                    samplers.SubsetRandomSampler, samplers.SubsetSampler, samplers.PKSampler,
3624                    samplers.DistributedSampler,
3625                    samplers.RandomSampler, samplers.SequentialSampler)) is False:
3626                raise ValueError("The sampler is not supported yet.")
3627
3628        self.padded_sample = padded_sample
3629        self.num_padded = replace_none(num_padded, 0)
3630
3631        self.new_padded_sample = {}
3632        if padded_sample:
3633            for k, v in padded_sample.items():
3634                if isinstance(v, np.ndarray):
3635                    self.new_padded_sample[k] = v.tobytes()
3636                else:
3637                    self.new_padded_sample[k] = v
3638
3639
3640def _iter_fn(dataset, num_samples):
3641    """
3642    Generator function wrapper for iterable dataset.
3643    """
3644    if num_samples is not None and num_samples != 0:
3645        ds_iter = iter(dataset)
3646        for _ in range(num_samples):
3647            try:
3648                val = next(ds_iter)
3649            except StopIteration:
3650                return
3651            # convert output tensors to ndarrays
3652            yield _convert_row(val)
3653    else:
3654        for val in dataset:
3655            # convert output tensors to ndarrays
3656            yield _convert_row(val)
3657
3658
3659def _generator_fn(generator, num_samples):
3660    """
3661    Generator function wrapper for generator function dataset.
3662    """
3663    if num_samples is not None and num_samples != 0:
3664        gen_iter = generator()
3665        for _ in range(num_samples):
3666            try:
3667                val = next(gen_iter)
3668            except StopIteration:
3669                return
3670            yield val
3671    else:
3672        gen_iter = generator()
3673        for val in gen_iter:
3674            yield val
3675
3676
3677def _cpp_sampler_fn(sample_ids, dataset):
3678    """
3679    Generator function wrapper for mappable dataset with cpp sampler.
3680    """
3681    if not isinstance(sample_ids, np.ndarray):
3682        raise RuntimeError("Sample IDs are not in a numpy array.")
3683    if sample_ids.size == 0:
3684        raise RuntimeError("Sampler passed an empty sample IDs list.")
3685
3686    for i in sample_ids:
3687        val = dataset[i]
3688        # convert output tensors to ndarrays
3689        yield _convert_row(val)
3690
3691
3692def _cpp_sampler_fn_mp(sample_ids, sample_fn):
3693    """
3694    Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
3695    """
3696    if not isinstance(sample_ids, np.ndarray):
3697        raise RuntimeError("Sample IDs are not in a numpy array.")
3698    if sample_ids.size == 0:
3699        raise RuntimeError("Sampler passed an empty sample IDs list.")
3700
3701    return sample_fn.process(sample_ids)
3702
3703
3704def _fill_worker_indices(workers, indices, idx):
3705    """
    Fill the worker index queues in round-robin order, starting at position `idx` of `indices`,
    and return the updated cursor.
3707    """
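    # Sketch of the distribution: with 3 workers and indices [0, 1, 2, 3, 4], worker 0 receives indices
    # 0 and 3, worker 1 receives 1 and 4, and worker 2 receives 2. Filling stops early and the current
    # cursor is returned as soon as one of the index queues is full.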
3708    num_worker = len(workers)
3709    while idx < len(indices):
3710        try:
3711            workers[idx % num_worker].put(indices[idx])
3712            idx += 1
3713        except queue.Full:
3714            break
3715    return idx
3716
3717
3718def _check_shm_usage(num_worker, queue_size, max_rowsize, num_queues=1):
3719    """
    Check that sufficient shared memory is available for the shared memory queues
    when training in parallel mode.
3722    """
3723    threshold_ratio = 0.8
3724    if platform.system() != "Windows" and _get_device_num() >= 1:
3725        shm_estimate_usage = _get_device_num() * num_worker * num_queues * \
3726            (queue_size + 2) * max_rowsize * 1024 * 1024
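        # Worked example with illustrative numbers: 8 devices, 4 workers, 1 queue, queue_size=16 and
        # max_rowsize=6 MB give an estimate of 8 * 4 * 1 * (16 + 2) * 6 MB = 3456 MB, i.e. about 3.4 GB.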
3727        try:
3728            shm_available = psutil.disk_usage('/dev/shm').free
3729            if shm_estimate_usage >= threshold_ratio * shm_available:
3730                raise RuntimeError(
                    "Insufficient shared memory available. Required: {}, Available: {}. "
                    "The required memory cannot exceed 80% of the available shared memory. "
                    "It is recommended to set enable_shared_mem to False, or to reduce max_rowsize or "
                    "num_parallel_workers."
                    .format(shm_estimate_usage, shm_available))
3735        except FileNotFoundError:
3736            raise RuntimeError("Expected /dev/shm to exist.")
3737
3738
3739def _convert_row(row):
3740    """
3741    Convert Op return value to numpy
3742    """
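    # For example, a row (image_bytes, Tensor([1, 2]), 3) becomes
    # (np.frombuffer(image_bytes, np.uint8), np.array([1, 2]), np.array(3)).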
3743    value = []
3744    # convert each column in row into numpy array
3745    for x in row:
3746        if isinstance(x, bytes):         # got image bytes from a file
3747            value.append(np.frombuffer(x, np.uint8))
3748        elif isinstance(x, Tensor):      # got mindspore.Tensor
3749            value.append(x.asnumpy())
3750        else:
3751            value.append(np.array(x, copy=False))
3752    return tuple(value)
3753
3754
3755class SamplerFn:
3756    """
    Master of the multiprocessing or multithreaded generator workers: it creates the workers,
    dispatches sample indices to them, and collects the resulting rows.
3758    """
3759
3760    def __init__(self, dataset, num_worker, multi_process, max_rowsize):
3761        self.workers = []
3762        self.num_worker = num_worker
3763        self.multi_process = multi_process
3764        self.need_join = False
3765        self.ppid = os.getpid()
3766        self.pids = []
3767        # Event for end of epoch
3768        if multi_process is True:
3769            try:
3770                self.eof = multiprocessing.Event()
3771            except Exception:
                raise RuntimeError("Init multiprocessing.Event() failed. This might be caused by insufficient shm;"
                                   + " the recommended shm size is at least 5 GB.")
3774        else:
3775            self.eof = threading.Event()
3776        # Create workers
3777
        # Get the default queue size and shrink the per-worker queue size when there are many workers.
3779        queue_size = get_prefetch_size()
3780        queue_size = min(queue_size, queue_size * 4 // num_worker)
3781        queue_size = max(2, queue_size)
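        # Illustrative numbers: if get_prefetch_size() returns 16 and there are 8 workers, each worker
        # queue holds min(16, 16 * 4 // 8) = 8 rows; with 64 workers the size would be clamped to the floor of 2.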
3782
3783        if multi_process and get_enable_shared_mem():
3784            _check_shm_usage(num_worker, queue_size, max_rowsize)
3785        for _ in range(num_worker):
3786            if multi_process is True:
3787                try:
3788                    worker = _GeneratorWorkerMp(dataset, self.eof, max_rowsize, queue_size)
3789                except Exception:
                    raise RuntimeError("Init multiprocessing.Queue() failed. This might be caused by insufficient shm;"
                                       + " the recommended shm size is at least 5 GB.")
3792                worker.daemon = True
3793                # When multi processes fork a subprocess, the lock of the main process is copied to the subprocess,
                # which may cause deadlock. Therefore, the subprocess startup is performed in the initialization phase.
3795                # In this phase, the main process is not locked.
3796                worker.start()
3797                self.pids.append(worker.pid)
3798                self.need_join = True
3799            else:
3800                worker = _GeneratorWorkerMt(dataset, self.eof)
3801                worker.daemon = True
3802            self.workers.append(worker)
3803        if multi_process is True and platform.system().lower() != 'windows':
3804            self.eot = threading.Event()
3805            self.watch_dog = threading.Thread(target=_watch_dog, args=(self.eot, self.pids))
3806            self.watch_dog.daemon = True
3807            self.watch_dog.start()
3808
3809    def process(self, indices):
3810        """
        Main entry of the master: start the child processes or threads, fill the index queues with `indices`,
        then fetch the results and yield them one row at a time.
3813        """
3814        for w in self.workers:
3815            # Check whether the queue of the subprocess is empty.
3816            if not w.queue_empty():
3817                raise Exception("The queue of the subprocess is not empty.")
3818            # Start all workers
3819            if not w.is_alive():
3820                w.start()
3821
3822        # Fill initial index queues
3823        idx_cursor = 0
3824        idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
3825
3826        # Fetch results
3827        for i in range(len(indices)):
3828            if self.eof.is_set():
3829                self._stop_subprocess()
3830                return
3831            if self.multi_process is True and not psutil.pid_exists(self.workers[i % self.num_worker].pid):
3832                self._stop_subprocess()
3833                return
3834            # Fetch result and put index
3835            try:
3836                result = self.workers[i % self.num_worker].get()
3837                if isinstance(result, ExceptionHandler):
3838                    result.reraise()
3839            except queue.Empty:
3840                self._stop_subprocess()
3841                raise Exception("Generator worker process timeout.")
3842            except KeyboardInterrupt:
3843                self._stop_subprocess()
3844                raise Exception("Generator worker receives KeyboardInterrupt.")
3845            if self.eof.is_set():
3846                self._stop_subprocess()
3847                return
3848            if idx_cursor < len(indices):
3849                idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
3850            yield _convert_row(result)
3851
3852    def _stop_subprocess(self):
3853        # Only the main process can call join
3854        if self.need_join is True and self.ppid == os.getpid():
3855            self.eof.set()
3856            self.need_join = False
3857            for w in self.workers:
3858                if psutil.pid_exists(w.pid):
3859                    w.join()
3860            self._abort_watchdog()
3861
3862    def _abort_watchdog(self):
3863        if hasattr(self, 'eot') and self.eot is not None and not self.eot.is_set():
3864            self.eot.set()
3865
3866    def __del__(self):
3867        self._stop_subprocess()
3868
3869
3870def _subprocess_handle(eof, signum, frame):
    threading.Thread(target=eof.set).start()
3872
3873
3874def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing):
3875    """
    Multithreaded or multiprocess generator worker loop.
3877    """
3878    if is_multiprocessing:
3879        signal.signal(signal.SIGTERM, partial(_subprocess_handle, eof))
3880    while True:
3881        # Fetch index, block
3882        try:
3883            idx = idx_queue.get(timeout=1)
3884        except KeyboardInterrupt:
3885            if is_multiprocessing:
3886                eof.set()
3887                idx_queue.cancel_join_thread()
3888                result_queue.cancel_join_thread()
3889            raise Exception("Generator worker receives KeyboardInterrupt.")
3890        except queue.Empty:
3891            if eof.is_set():
3892                if is_multiprocessing:
3893                    idx_queue.cancel_join_thread()
3894                    result_queue.cancel_join_thread()
3895                return
3896            # If end-of-file (eof) is not set, continue to get data from idx_queue
3897            continue
3898        if idx is None:
3899            # When the queue is out of scope from master process, a None item can be fetched from the queue.
3900            # Upon receiving None, worker process should check if eof is set.
3901            if not eof.is_set():
                raise Exception("Generator worker received a None index while eof is not set.")
3903            return
3904        if eof.is_set():
3905            if is_multiprocessing:
3906                idx_queue.cancel_join_thread()
3907                result_queue.cancel_join_thread()
3908            return
3909        # Fetch data, any exception from __getitem__ will terminate worker and timeout master process
3910        try:
3911            result = dataset[idx]
3912        except Exception:
3913            result = ExceptionHandler(where="in GeneratorDataset worker process")
3914        # Send data, block
3915        while True:
3916            try:
3917                result_queue.put(result, timeout=5)
3918            except KeyboardInterrupt:
3919                if is_multiprocessing:
3920                    eof.set()
3921                    idx_queue.cancel_join_thread()
3922                    result_queue.cancel_join_thread()
3923                raise Exception("Generator worker receives KeyboardInterrupt.")
3924            except queue.Full:
3925                if eof.is_set():
3926                    if is_multiprocessing:
3927                        idx_queue.cancel_join_thread()
3928                        result_queue.cancel_join_thread()
3929                    return
3930                # If eof is not set, continue to put data to result_queue
3931                continue
3932            break
3933        del result, idx
3934
3935
3936class _GeneratorWorkerMt(threading.Thread):
3937    """
    Worker thread for the multithreaded Generator.
3939    """
3940
3941    def __init__(self, dataset, eof):
3942        self.idx_queue = queue.Queue(16)
3943        self.res_queue = queue.Queue(16)
3944        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, False))
3945
3946    def put(self, item):
3947        """
3948        Put function for worker index queue. Never block. Raise queue.Full on failure.
3949        """
3950        self.idx_queue.put_nowait(item)
3951
3952    def get(self):
3953        """
3954        Get function for worker result queue. Block with timeout.
3955        """
3956        return self.res_queue.get(timeout=30)
3957
3958    def queue_empty(self):
3959        if not self.idx_queue.empty():
3960            logger.warning("idx_queue is not empty")
3961            return False
3962        if not self.res_queue.empty():
3963            logger.warning("res_queue is not empty")
3964            return False
3965        return True
3966
3967
3968class _GeneratorWorkerMp(multiprocessing.Process):
3969    """
3970    Worker process for multiprocess Generator.
3971    """
3972
3973    def __init__(self, dataset, eof, max_rowsize, queue_size):
3974        self.idx_queue = multiprocessing.Queue(queue_size)
3975        if get_enable_shared_mem():
3976            self.res_queue = _SharedQueue(queue_size, max_rowsize=max_rowsize)
3977        else:
3978            self.res_queue = multiprocessing.Queue(queue_size)
3979        self.idx_queue._joincancelled = True  # pylint: disable=W0212
3980        self.res_queue._joincancelled = True  # pylint: disable=W0212
3981        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True))
3982
3983    def put(self, item):
3984        """
3985        Put function for worker index queue. Never block. Raise queue.Full on failure.
3986        """
3987        self.idx_queue.put_nowait(item)
3988
3989    def get(self):
3990        """
3991        Get function for worker result queue. Block with timeout.
3992        """
        # The timeout is relaxed from 10s to 30s, since 10s sometimes causes "Generator worker process timeout"
        # when too many iterators run with an infinite number of epochs (num_epoch=-1).
3995        return self.res_queue.get(timeout=30)
3996
3997    def queue_empty(self):
3998        if not self.idx_queue.empty():
3999            logger.warning("idx_queue is not empty.")
4000            return False
4001        if not self.res_queue.empty():
4002            logger.warning("res_queue is not empty.")
4003            return False
4004        return True
4005
4006
4007class GeneratorDataset(MappableDataset):
4008    """
    A source dataset that generates data from Python by invoking the Python data source each epoch.

    The column names and column types of the generated dataset depend on the Python data defined by users.
4012
4013    Args:
4014        source (Union[Callable, Iterable, Random Accessible]):
4015            A generator callable object, an iterable Python object or a random accessible Python object.
            Callable source is required to return a tuple of NumPy arrays as a row of the dataset on
            next(source()).
            Iterable source is required to return a tuple of NumPy arrays as a row of the dataset on
            next(iter(source)).
4019            Random accessible source is required to return a tuple of NumPy arrays as a row of the dataset on
4020            source[idx].
4021        column_names (Union[str, list[str]], optional): List of column names of the dataset (default=None). Users are
4022            required to provide either column_names or schema.
4023        column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
4024            If provided, sanity check will be performed on generator output.
4025        schema (Union[Schema, str], optional): Path to the JSON schema file or schema object (default=None). Users are
4026            required to provide either column_names or schema. If both are provided, schema will be used.
4027        num_samples (int, optional): The number of samples to be included in the dataset
            (default=None, all samples).
4029        num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
4030        shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
4031            (default=None, expected order behavior shown in the table).
4032        sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. Random accessible
4033            input is required (default=None, expected order behavior shown in the table).
4034        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
            Random accessible input is required. When this argument is specified, `num_samples` reflects the
            maximum number of samples per shard.
4037        shard_id (int, optional): The shard ID within num_shards (default=None). This argument must be specified only
4038            when num_shards is also specified. Random accessible input is required.
        python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
            option could be beneficial if the Python operation is computationally heavy (default=True).
        max_rowsize (int, optional): Maximum size of a row in MB that is used for shared memory allocation to copy
            data between processes. This is only used if python_multiprocessing is set to True (default=6 MB).
4043
4044    Raises:
4045        RuntimeError: If source raises an exception during execution.
4046        RuntimeError: If len of column_names does not match output len of source.
4047        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
4048        RuntimeError: If sampler and shuffle are specified at the same time.
4049        RuntimeError: If sampler and sharding are specified at the same time.
4050        RuntimeError: If num_shards is specified but shard_id is None.
4051        RuntimeError: If shard_id is specified but num_shards is None.
4052        ValueError: If shard_id is invalid (< 0 or >= num_shards).
4053
4054    Note:
4055        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
4056          The table below shows what input arguments are allowed and their expected behavior.
4057
4058    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
4059       :widths: 25 25 50
4060       :header-rows: 1
4061
4062       * - Parameter `sampler`
4063         - Parameter `shuffle`
4064         - Expected Order Behavior
4065       * - None
4066         - None
4067         - random order
4068       * - None
4069         - True
4070         - random order
4071       * - None
4072         - False
4073         - sequential order
4074       * - Sampler object
4075         - None
4076         - order defined by sampler
4077       * - Sampler object
4078         - True
4079         - not allowed
4080       * - Sampler object
4081         - False
4082         - not allowed
4083
4084    Examples:
4085        >>> import numpy as np
4086        >>>
4087        >>> # 1) Multidimensional generator function as callable input.
4088        >>> def generator_multidimensional():
4089        ...     for i in range(64):
4090        ...         yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
4091        >>>
4092        >>> dataset = ds.GeneratorDataset(source=generator_multidimensional, column_names=["multi_dimensional_data"])
4093        >>>
4094        >>> # 2) Multi-column generator function as callable input.
4095        >>> def generator_multi_column():
4096        ...     for i in range(64):
4097        ...         yield np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])
4098        >>>
4099        >>> dataset = ds.GeneratorDataset(source=generator_multi_column, column_names=["col1", "col2"])
4100        >>>
4101        >>> # 3) Iterable dataset as iterable input.
4102        >>> class MyIterable:
4103        ...     def __init__(self):
4104        ...         self._index = 0
4105        ...         self._data = np.random.sample((5, 2))
4106        ...         self._label = np.random.sample((5, 1))
4107        ...
4108        ...     def __next__(self):
4109        ...         if self._index >= len(self._data):
4110        ...             raise StopIteration
4111        ...         else:
4112        ...             item = (self._data[self._index], self._label[self._index])
4113        ...             self._index += 1
4114        ...             return item
4115        ...
4116        ...     def __iter__(self):
4117        ...         self._index = 0
4118        ...         return self
4119        ...
4120        ...     def __len__(self):
4121        ...         return len(self._data)
4122        >>>
4123        >>> dataset = ds.GeneratorDataset(source=MyIterable(), column_names=["data", "label"])
4124        >>>
4125        >>> # 4) Random accessible dataset as random accessible input.
4126        >>> class MyAccessible:
4127        ...     def __init__(self):
4128        ...         self._data = np.random.sample((5, 2))
4129        ...         self._label = np.random.sample((5, 1))
4130        ...
4131        ...     def __getitem__(self, index):
4132        ...         return self._data[index], self._label[index]
4133        ...
4134        ...     def __len__(self):
4135        ...         return len(self._data)
4136        >>>
4137        >>> dataset = ds.GeneratorDataset(source=MyAccessible(), column_names=["data", "label"])
4138        >>>
4139        >>> # list, dict, tuple of Python is also random accessible
4140        >>> dataset = ds.GeneratorDataset(source=[(np.array(0),), (np.array(1),), (np.array(2),)], column_names=["col"])
4141    """
4142
4143    @check_generatordataset
4144    def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
4145                 num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
4146                 python_multiprocessing=True, max_rowsize=6):
4147        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
4148                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
4149        self.source = source
4150        self.prepared_source = None  # source to be sent to C++
4151
4152        self.python_multiprocessing = python_multiprocessing
4153
4154        self.column_names = to_list(column_names)
4155
4156        if column_types is not None:
4157            self.column_types = mstypelist_to_detypelist(column_types)
4158        else:
4159            self.column_types = []
4160
        self.schema = schema
        if schema is not None and not isinstance(schema, Schema):
            self.schema = Schema(schema)
        # Move getting the dataset size via len() from parse() to here, because self.source will
        # lose its '__len__' attribute after deepcopy.
4168        self.source_len = -1  # unknown
4169        if hasattr(self.source, "__len__"):
4170            self.source_len = len(self.source)
4171
4172        self.max_rowsize = max_rowsize
4173        self.sample_fn = None
4174
4175    def __deepcopy__(self, memodict):
4176        if id(self) in memodict:
4177            return memodict[id(self)]
4178        new_op = self.__safe_deepcopy__(memodict, exclude=("source", "__transfer_dataset__"))
4179
4180        sample_fn = None
4181        if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
            # A try/except is needed here because constructing the new op with shared memory enabled
            # can raise an exception when there is not enough shared memory available.
4184            if self.source_len == -1:
                raise RuntimeError("Attempt to construct a random access dataset; the '__len__' method is required!")
4186            try:
4187                if new_op.num_parallel_workers > 1:
4188                    sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing,
4189                                          self.max_rowsize)
4190                    new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
4191                else:
4192                    new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
4193                new_op.sample_fn = sample_fn
4194            except RuntimeError as e:
4195                raise Exception(str(e))
4196        else:
4197            try:
4198                new_op.sampler = None
4199                new_op.sample_fn = sample_fn
4200                new_op.source_len = min(new_op.source_len,
4201                                        new_op.num_samples) if new_op.num_samples != 0 else new_op.source_len
4202                iter(self.source)
4203            except TypeError:
4204                # Use generator function if input callable
4205                new_op.prepared_source = (lambda: _generator_fn(self.source, new_op.num_samples))
4206            else:
4207                # Use iterator function if input is iterable
4208                # Random accessible input is also iterable
4209                new_op.prepared_source = (lambda: _iter_fn(self.source, new_op.num_samples))
4210
4211        return new_op
4212
4213    def is_shuffled(self):
4214        return self.sampler.is_shuffled()
4215
4216    def is_sharded(self):
4217        return self.sampler.is_sharded()
4218
4219    def parse(self, children=None):
4220        if self.schema is None:
4221            return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
4222                                     self.sampler, self.num_parallel_workers)
4223        schema = self.schema
4224        if isinstance(schema, Schema):
4225            schema = self.schema.cpp_schema
4226        return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler,
4227                                 self.num_parallel_workers)
4228
4229
4230class TFRecordDataset(SourceDataset):
4231    """
4232    A source dataset for reading and parsing datasets stored on disk in TFData format.
4233
    The columns of the generated dataset depend on the source TFRecord files.
4235
4236    Args:
4237        dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a
4238            pattern of files. The list will be sorted in a lexicographical order.
4239        schema (Union[str, Schema], optional): Path to the JSON schema file or schema object (default=None).
            If the schema is not provided, the metadata from the TFData file is considered the schema.
4241        columns_list (list[str], optional): List of columns to be read (default=None, read all columns).
4242        num_samples (int, optional): The number of samples (rows) to be included in the dataset (default=None).
4243            If num_samples is None and numRows(parsed from schema) does not exist, read the full dataset;
4244            If num_samples is None and numRows(parsed from schema) is greater than 0, read numRows rows;
4245            If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows.
4246        num_parallel_workers (int, optional): Number of workers to read the data
4247            (default=None, number set in the config).
4248        shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
4249            (default=Shuffle.GLOBAL).
            If shuffle is False, no shuffling will be performed.
            If shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL.
            Otherwise, there are two levels of shuffling:
4253
4254            - Shuffle.GLOBAL: Shuffle both the files and samples.
4255
4256            - Shuffle.FILES: Shuffle files only.
4257
4258        num_shards (int, optional): Number of shards that the dataset will be divided
            into (default=None). When this argument is specified, `num_samples` reflects
            the maximum number of samples per shard.
4261        shard_id (int, optional): The shard ID within num_shards (default=None). This
4262            argument can only be specified when num_shards is also specified.
        shard_equal_rows (bool, optional): Get equal rows for all shards (default=False). If shard_equal_rows
            is False, the number of rows in each shard may not be equal, which may lead to a failure in distributed
            training. When the number of samples per TFRecord file is not equal, it is suggested to set this to True.
4266            This argument should only be specified when num_shards is also specified.
4267        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
4268            (default=None, which means no cache is used).
4269
4270    Raises:
4271        RuntimeError: If dataset_files are not valid or do not exist.
4272        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
4273        RuntimeError: If num_shards is specified but shard_id is None.
4274        RuntimeError: If shard_id is specified but num_shards is None.
4275        ValueError: If shard_id is invalid (< 0 or >= num_shards).
4276
4277    Examples:
4278        >>> from mindspore import dtype as mstype
4279        >>>
4280        >>> tfrecord_dataset_dir = ["/path/to/tfrecord_dataset_file"] # contains 1 or multiple TFRecord files
4281        >>> tfrecord_schema_file = "/path/to/tfrecord_schema_file"
4282        >>>
4283        >>> # 1) Get all rows from tfrecord_dataset_dir with no explicit schema.
4284        >>> # The meta-data in the first row will be used as a schema.
4285        >>> dataset = ds.TFRecordDataset(dataset_files=tfrecord_dataset_dir)
4286        >>>
4287        >>> # 2) Get all rows from tfrecord_dataset_dir with user-defined schema.
4288        >>> schema = ds.Schema()
4289        >>> schema.add_column(name='col_1d', de_type=mstype.int64, shape=[2])
4290        >>> dataset = ds.TFRecordDataset(dataset_files=tfrecord_dataset_dir, schema=schema)
4291        >>>
4292        >>> # 3) Get all rows from tfrecord_dataset_dir with schema file.
4293        >>> dataset = ds.TFRecordDataset(dataset_files=tfrecord_dataset_dir, schema=tfrecord_schema_file)
4294    """
4295
4296    @check_tfrecorddataset
4297    def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
4298                 shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None):
4299        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
4300                         num_shards=num_shards, shard_id=shard_id, cache=cache)
4301        self.dataset_files = self._find_files(dataset_files)
4302        self.dataset_files.sort()
4303
4304        self.schema = schema
4305        self.columns_list = replace_none(columns_list, [])
4306        self.shard_equal_rows = replace_none(shard_equal_rows, False)
4307
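        # If a schema is provided and num_samples is not set, fall back to the numRows value recorded in the
        # schema (see the num_samples description in the class docstring above).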
4308        if self.schema is not None and (self.num_samples is None or self.num_samples == 0):
4309            self.num_samples = Schema.get_num_rows(self.schema)
4310
4311    def parse(self, children=None):
4312        schema = self.schema.cpp_schema if isinstance(self.schema, Schema) else self.schema
4313        return cde.TFRecordNode(self.dataset_files, schema, self.columns_list, self.num_samples, self.shuffle_flag,
4314                                self.num_shards, self.shard_id, self.shard_equal_rows)
4315
4316
4317class ManifestDataset(MappableDataset):
4318    """
4319    A source dataset for reading images from a Manifest file.
4320
4321    The generated dataset has two columns: :py:obj:`[image, label]`.
4322    The tensor of column :py:obj:`image` is of the uint8 type.
    The tensor of column :py:obj:`label` is a scalar of the uint64 type.
4324
4325    Args:
4326        dataset_file (str): File to be read.
4327        usage (str, optional): Acceptable usages include `train`, `eval` and `inference` (default=`train`).
4328        num_samples (int, optional): The number of images to be included in the dataset.
4329            (default=None, will include all images).
4330        num_parallel_workers (int, optional): Number of workers to read the data
4331            (default=None, will use value set in the config).
4332        shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
4333            order behavior shown in the table).
4334        sampler (Sampler, optional): Object used to choose samples from the
4335            dataset (default=None, expected order behavior shown in the table).
4336        class_indexing (dict, optional): A str-to-int mapping from label name to index
4337            (default=None, the folder names will be sorted alphabetically and each
4338            class will be given a unique index starting from 0).
        decode (bool, optional): Whether to decode the images after reading (default=False).
4340        num_shards (int, optional): Number of shards that the dataset will be divided
4341            into (default=None). When this argument is specified, `num_samples` reflects
4342            the max number of samples per shard.
4343        shard_id (int, optional): The shard ID within `num_shards` (default=None). This
4344            argument can only be specified when `num_shards` is also specified.
4345        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
4346            (default=None, which means no cache is used).
4347
4348    Raises:
4349        RuntimeError: If dataset_files are not valid or do not exist.
4350        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
4351        RuntimeError: If sampler and shuffle are specified at the same time.
4352        RuntimeError: If sampler and sharding are specified at the same time.
4353        RuntimeError: If num_shards is specified but shard_id is None.
4354        RuntimeError: If shard_id is specified but num_shards is None.
4355        RuntimeError: If class_indexing is not a dictionary.
4356        ValueError: If shard_id is invalid (< 0 or >= num_shards).
4357
4358    Note:
4359        - The shape of the image column is [image_size] if decode flag is False, or [H,W,C] otherwise.
4360        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
4361          The table below shows what input arguments are allowed and their expected behavior.
4362
4363    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
4364       :widths: 25 25 50
4365       :header-rows: 1
4366
4367       * - Parameter `sampler`
4368         - Parameter `shuffle`
4369         - Expected Order Behavior
4370       * - None
4371         - None
4372         - random order
4373       * - None
4374         - True
4375         - random order
4376       * - None
4377         - False
4378         - sequential order
4379       * - Sampler object
4380         - None
4381         - order defined by sampler
4382       * - Sampler object
4383         - True
4384         - not allowed
4385       * - Sampler object
4386         - False
4387         - not allowed
4388
4389    Examples:
4390        >>> manifest_dataset_dir = "/path/to/manifest_dataset_file"
4391        >>>
4392        >>> # 1) Read all samples specified in manifest_dataset_dir dataset with 8 threads for training
4393        >>> dataset = ds.ManifestDataset(dataset_file=manifest_dataset_dir, usage="train", num_parallel_workers=8)
4394        >>>
4395        >>> # 2) Read samples (specified in manifest_file.manifest) for shard 0 in a 2-way distributed training setup
4396        >>> dataset = ds.ManifestDataset(dataset_file=manifest_dataset_dir, num_shards=2, shard_id=0)
4397    """
4398
4399    @check_manifestdataset
4400    def __init__(self, dataset_file, usage="train", num_samples=None, num_parallel_workers=None, shuffle=None,
4401                 sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None, cache=None):
4402        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
4403                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
4404
4405        self.dataset_file = dataset_file
4406        self.decode = replace_none(decode, False)
4407        self.usage = replace_none(usage, "train")
4408        self.class_indexing = replace_none(class_indexing, {})
4409
4410    def parse(self, children=None):
4411        return cde.ManifestNode(self.dataset_file, self.usage, self.sampler, self.class_indexing, self.decode)
4412
4413    def get_class_indexing(self):
4414        """
4415        Get the class index.
4416
4417        Returns:
4418            dict, a str-to-int mapping from label name to index.
4419
4420        Examples:
4421            >>> manifest_dataset_dir = "/path/to/manifest_dataset_file"
4422            >>>
4423            >>> dataset = ds.ManifestDataset(dataset_file=manifest_dataset_dir)
4424            >>> class_indexing = dataset.get_class_indexing()
4425        """
4426        if self.class_indexing is None or not self.class_indexing:
4427            if self._class_indexing is None:
4428                runtime_getter = self._init_tree_getters()
4429                self._class_indexing = runtime_getter[0].GetClassIndexing()
4430            self.class_indexing = {}
4431            for pair in self._class_indexing:
4432                self.class_indexing[pair[0]] = pair[1][0]
4433        return self.class_indexing
4434
4435
4436class Cifar10Dataset(MappableDataset):
4437    """
    A source dataset for reading and parsing the Cifar10 dataset.
    Currently, this API only supports parsing Cifar10 files in the binary version.
4440
4441    The generated dataset has two columns :py:obj:`[image, label]`.
4442    The tensor of column :py:obj:`image` is of the uint8 type.
4443    The tensor of column :py:obj:`label` is a scalar of the uint32 type.
4444
4445    Args:
4446        dataset_dir (str): Path to the root directory that contains the dataset.
4447        usage (str, optional): Usage of this dataset, can be `train`, `test` or `all` . `train` will read from 50,000
4448            train samples, `test` will read from 10,000 test samples, `all` will read from all 60,000 samples
4449            (default=None, all samples).
4450        num_samples (int, optional): The number of images to be included in the dataset
4451            (default=None, all images).
4452        num_parallel_workers (int, optional): Number of workers to read the data
4453            (default=None, number set in the config).
4454        shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
4455            order behavior shown in the table).
4456        sampler (Sampler, optional): Object used to choose samples from the
4457            dataset (default=None, expected order behavior shown in the table).
4458        num_shards (int, optional): Number of shards that the dataset will be divided
            into (default=None). When this argument is specified, `num_samples` reflects
            the maximum number of samples per shard.
4461        shard_id (int, optional): The shard ID within num_shards (default=None). This
4462            argument can only be specified when num_shards is also specified.
4463        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
4464            (default=None, which means no cache is used).
4465
4466    Raises:
4467        RuntimeError: If dataset_dir does not contain data files.
4468        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
4469        RuntimeError: If sampler and shuffle are specified at the same time.
4470        RuntimeError: If sampler and sharding are specified at the same time.
4471        RuntimeError: If num_shards is specified but shard_id is None.
4472        RuntimeError: If shard_id is specified but num_shards is None.
4473        ValueError: If shard_id is invalid (< 0 or >= num_shards).
4474
4475    Note:
4476        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
4477          The table below shows what input arguments are allowed and their expected behavior.
4478
4479    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
4480       :widths: 25 25 50
4481       :header-rows: 1
4482
4483       * - Parameter `sampler`
4484         - Parameter `shuffle`
4485         - Expected Order Behavior
4486       * - None
4487         - None
4488         - random order
4489       * - None
4490         - True
4491         - random order
4492       * - None
4493         - False
4494         - sequential order
4495       * - Sampler object
4496         - None
4497         - order defined by sampler
4498       * - Sampler object
4499         - True
4500         - not allowed
4501       * - Sampler object
4502         - False
4503         - not allowed
4504
4505    Examples:
4506        >>> cifar10_dataset_dir = "/path/to/cifar10_dataset_directory"
4507        >>>
4508        >>> # 1) Get all samples from CIFAR10 dataset in sequence
4509        >>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, shuffle=False)
4510        >>>
4511        >>> # 2) Randomly select 350 samples from CIFAR10 dataset
4512        >>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_samples=350, shuffle=True)
4513        >>>
4514        >>> # 3) Get samples from CIFAR10 dataset for shard 0 in a 2-way distributed training
4515        >>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_shards=2, shard_id=0)
4516        >>>
4517        >>> # In CIFAR10 dataset, each dictionary has keys "image" and "label"
4518
4519    About CIFAR-10 dataset:
4520
4521    The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes,
4522    with 6000 images per class. There are 50000 training images and 10000 test images.
4523    The 10 different classes represent airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks.
4524
4525    Here is the original CIFAR-10 dataset structure.
4526    You can unzip the dataset files into the following directory structure and read by MindSpore's API.
4527
4528    .. code-block::
4529
4530        .
4531        └── cifar-10-batches-bin
4532             ├── data_batch_1.bin
4533             ├── data_batch_2.bin
4534             ├── data_batch_3.bin
4535             ├── data_batch_4.bin
4536             ├── data_batch_5.bin
4537             ├── test_batch.bin
4538             ├── readme.html
4539             └── batches.meta.txt
4540
4541    Citation:
4542
4543    .. code-block::
4544
4545        @techreport{Krizhevsky09,
4546        author       = {Alex Krizhevsky},
4547        title        = {Learning multiple layers of features from tiny images},
4548        institution  = {},
4549        year         = {2009},
4550        howpublished = {http://www.cs.toronto.edu/~kriz/cifar.html}
4551        }
4552    """
4553
4554    @check_mnist_cifar_dataset
4555    def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None,
4556                 sampler=None, num_shards=None, shard_id=None, cache=None):
4557        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
4558                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
4559
4560        self.dataset_dir = dataset_dir
4561        self.usage = replace_none(usage, "all")
4562
4563    def parse(self, children=None):
4564        return cde.Cifar10Node(self.dataset_dir, self.usage, self.sampler)
4565
4566
4567class Cifar100Dataset(MappableDataset):
4568    """
    A source dataset for reading and parsing the Cifar100 dataset.

    The generated dataset has three columns :py:obj:`[image, coarse_label, fine_label]`.
    The tensor of column :py:obj:`image` is of the uint8 type.
    The tensors of columns :py:obj:`coarse_label` and :py:obj:`fine_label` are each a scalar of the uint32 type.
4574
4575    Args:
4576        dataset_dir (str): Path to the root directory that contains the dataset.
4577        usage (str, optional): Usage of this dataset, can be `train`, `test` or `all` . `train` will read from 50,000
4578            train samples, `test` will read from 10,000 test samples, `all` will read from all 60,000 samples
4579            (default=None, all samples).
4580        num_samples (int, optional): The number of images to be included in the dataset
4581            (default=None, all images).
4582        num_parallel_workers (int, optional): Number of workers to read the data
4583            (default=None, number set in the config).
4584        shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
4585            order behavior shown in the table).
4586        sampler (Sampler, optional): Object used to choose samples from the
4587            dataset (default=None, expected order behavior shown in the table).
4588        num_shards (int, optional): Number of shards that the dataset will be divided
            into (default=None). When this argument is specified, `num_samples` reflects
            the maximum number of samples per shard.
4591        shard_id (int, optional): The shard ID within num_shards (default=None). This
4592            argument can only be specified when num_shards is also specified.
4593        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
4594            (default=None, which means no cache is used).
4595
4596    Raises:
4597        RuntimeError: If dataset_dir does not contain data files.
4598        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
4599        RuntimeError: If sampler and shuffle are specified at the same time.
4600        RuntimeError: If sampler and sharding are specified at the same time.
4601        RuntimeError: If num_shards is specified but shard_id is None.
4602        RuntimeError: If shard_id is specified but num_shards is None.
4603        ValueError: If shard_id is invalid (< 0 or >= num_shards).
4604
4605    Note:
4606        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
4607          The table below shows what input arguments are allowed and their expected behavior.
4608
    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
4610       :widths: 25 25 50
4611       :header-rows: 1
4612
4613       * - Parameter `sampler`
4614         - Parameter `shuffle`
4615         - Expected Order Behavior
4616       * - None
4617         - None
4618         - random order
4619       * - None
4620         - True
4621         - random order
4622       * - None
4623         - False
4624         - sequential order
4625       * - Sampler object
4626         - None
4627         - order defined by sampler
4628       * - Sampler object
4629         - True
4630         - not allowed
4631       * - Sampler object
4632         - False
4633         - not allowed
4634
4635    Examples:
4636        >>> cifar100_dataset_dir = "/path/to/cifar100_dataset_directory"
4637        >>>
4638        >>> # 1) Get all samples from CIFAR100 dataset in sequence
4639        >>> dataset = ds.Cifar100Dataset(dataset_dir=cifar100_dataset_dir, shuffle=False)
4640        >>>
4641        >>> # 2) Randomly select 350 samples from CIFAR100 dataset
4642        >>> dataset = ds.Cifar100Dataset(dataset_dir=cifar100_dataset_dir, num_samples=350, shuffle=True)
4643        >>>
4644        >>> # In CIFAR100 dataset, each dictionary has 3 keys: "image", "fine_label" and "coarse_label"
4645
4646    About CIFAR-100 dataset:
4647
4648    This dataset is just like the CIFAR-10, except it has 100 classes containing 600 images
4649    each. There are 500 training images and 100 testing images per class. The 100 classes in
4650    the CIFAR-100 are grouped into 20 superclasses. Each image comes with a "fine" label (the
4651    class to which it belongs) and a "coarse" label (the superclass to which it belongs).
4652
4653    Here is the original CIFAR-100 dataset structure.
4654    You can unzip the dataset files into the following directory structure and read by MindSpore's API.
4655
4656    .. code-block::
4657
4658        .
4659        └── cifar-100-binary
4660            ├── train.bin
4661            ├── test.bin
4662            ├── fine_label_names.txt
4663            └── coarse_label_names.txt
4664
4665    Citation:
4666
4667    .. code-block::
4668
4669        @techreport{Krizhevsky09,
4670        author       = {Alex Krizhevsky},
4671        title        = {Learning multiple layers of features from tiny images},
4672        institution  = {},
4673        year         = {2009},
4674        howpublished = {http://www.cs.toronto.edu/~kriz/cifar.html}
4675        }
4676    """
4677
4678    @check_mnist_cifar_dataset
4679    def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None,
4680                 sampler=None, num_shards=None, shard_id=None, cache=None):
4681        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
4682                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
4683
4684        self.dataset_dir = dataset_dir
4685        self.usage = replace_none(usage, "all")
4686
4687    def parse(self, children=None):
4688        return cde.Cifar100Node(self.dataset_dir, self.usage, self.sampler)
4689
4690
4691class RandomDataset(SourceDataset):
4692    """
4693    A source dataset that generates random data.
4694
4695    Args:
4696        total_rows (int, optional): Number of samples for the dataset to generate
4697            (default=None, number of samples is random).
4698        schema (Union[str, Schema], optional): Path to the JSON schema file or schema object (default=None).
4699            If the schema is not provided, the random dataset generates a random schema.
4700        columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
4701        num_samples (int, optional): The number of samples to be included in the dataset
4702            (default=None, all samples).
4703        num_parallel_workers (int, optional): Number of workers to read the data
4704            (default=None, number set in the config).
4705        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
4706            (default=None, which means no cache is used).
4707        shuffle (bool, optional): Whether or not to perform shuffle on the dataset
4708            (default=None, expected order behavior shown in the table).
4709        num_shards (int, optional): Number of shards that the dataset will be divided
            into (default=None). When this argument is specified, `num_samples` reflects
            the maximum number of samples per shard.
4712        shard_id (int, optional): The shard ID within num_shards (default=None). This
4713            argument can only be specified when num_shards is also specified.
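
    Examples:
        >>> from mindspore import dtype as mstype
        >>>
        >>> # A minimal sketch (the column name and shape are illustrative): generate 10 random rows
        >>> # that follow a user-defined schema.
        >>> schema = ds.Schema()
        >>> schema.add_column(name='image', de_type=mstype.uint8, shape=[2, 2])
        >>> dataset = ds.RandomDataset(schema=schema, total_rows=10)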
4714    """
4715
4716    @check_random_dataset
4717    def __init__(self, total_rows=None, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
4718                 cache=None, shuffle=None, num_shards=None, shard_id=None):
4719        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
4720                         num_shards=num_shards, shard_id=shard_id, cache=cache)
4721        self.total_rows = total_rows
4722        if schema is not None:
4723            self.total_rows = replace_none(total_rows, Schema.get_num_rows(schema))
4724        self.schema = schema
4725        self.columns_list = replace_none(columns_list, [])
4726
4727    def parse(self, children=None):
4728        schema = self.schema.cpp_schema if isinstance(self.schema, Schema) else self.schema
4729        return cde.RandomNode(self.total_rows, schema, self.columns_list)
4730
4731
4732class Schema:
4733    """
4734    Class to represent a schema of a dataset.
4735
4736    Args:
        schema_file (str): Path of the schema file (default=None).
4738
4739    Returns:
4740        Schema object, schema info about dataset.
4741
4742    Raises:
4743        RuntimeError: If schema file failed to load.
4744
4745    Examples:
4746        >>> from mindspore import dtype as mstype
4747        >>>
4748        >>> # Create schema; specify column name, mindspore.dtype and shape of the column
4749        >>> schema = ds.Schema()
4750        >>> schema.add_column(name='col1', de_type=mstype.int64, shape=[2])
4751    """
4752
4753    @check_schema
4754    def __init__(self, schema_file=None):
4755        self.schema_file = replace_none(schema_file, "")
4756        self.cpp_schema = cde.SchemaObj(self.schema_file)
4757
4758    @check_add_column
4759    def add_column(self, name, de_type, shape=None):
4760        """
        Add a new column to the schema.

        Args:
            name (str): The name of the new column.
4765            de_type (str): Data type of the column.
4766            shape (list[int], optional): Shape of the column
4767                (default=None, [-1] which is an unknown shape of rank 1).
4768
4769        Raises:
4770            ValueError: If column type is unknown.
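
        Examples:
            >>> from mindspore import dtype as mstype
            >>>
            >>> # A minimal sketch (column names are illustrative): add a column with a known shape and
            >>> # one with the default unknown shape.
            >>> schema = ds.Schema()
            >>> schema.add_column(name='col_1d', de_type=mstype.int64, shape=[2])
            >>> schema.add_column(name='col_unknown_shape', de_type=mstype.uint8)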
4771        """
4772        if isinstance(de_type, typing.Type):
4773            de_type = mstype_to_detype(de_type)
4774            col_type = str(de_type)
4775        else:
4776            col_type = str(cde.DataType(de_type))
4777        if shape is None:
4778            self.cpp_schema.add_column(name, col_type)
4779        else:
4780            self.cpp_schema.add_column(name, col_type, shape)
4781
4782    def parse_columns(self, columns):
4783        """
        Parse the columns and add them to self.
4785
4786        Args:
4787            columns (Union[dict, list[dict], tuple[dict]]): Dataset attribute information, decoded from schema file.
4788
4789                - list[dict], 'name' and 'type' must be in keys, 'shape' optional.
4790
4791                - dict, columns.keys() as name, columns.values() is dict, and 'type' inside, 'shape' optional.
4792
4793        Raises:
4794            RuntimeError: If failed to parse columns.
4795            RuntimeError: If column's name field is missing.
4796            RuntimeError: If column's type field is missing.
4797
4798        Examples:
            >>> schema = ds.Schema()
            >>> columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]},
            ...             {'name': 'label', 'type': 'int8', 'shape': [1]}]
4802            >>> schema.parse_columns(columns1)
4803            >>> columns2 = {'image': {'shape': [3, 3], 'type': 'int8'}, 'label': {'shape': [1], 'type': 'int8'}}
4804            >>> schema.parse_columns(columns2)
4805        """
4806        self.cpp_schema.parse_columns(json.dumps(columns, indent=2))
4807
4808    def to_json(self):
4809        """
4810        Get a JSON string of the schema.
4811
4812        Returns:
4813            str, JSON string of the schema.
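
        Examples:
            >>> from mindspore import dtype as mstype
            >>>
            >>> # A minimal sketch: serialize a schema built in Python to a JSON string.
            >>> schema = ds.Schema()
            >>> schema.add_column(name='col1', de_type=mstype.int64, shape=[2])
            >>> json_str = schema.to_json()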
4814        """
4815        return self.cpp_schema.to_json()
4816
4817    def from_json(self, json_obj):
4818        """
        Load the schema from a parsed JSON object.

        Args:
            json_obj (dict): Parsed JSON object.

        Raises:
            RuntimeError: If there is an unknown item in the object.
            RuntimeError: If the dataset type is missing in the object.
            RuntimeError: If columns are missing in the object.
4828        """
4829        self.cpp_schema.from_string(json.dumps(json_obj, indent=2))
4830
4831    def __str__(self):
4832        return self.to_json()
4833
4834    @staticmethod
4835    def get_num_rows(schema):
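        """
        Utility to get the number of rows from a schema.

        Args:
            schema (Union[str, Schema]): A Schema object or a path to a schema file.

        Returns:
            int, the number of rows recorded in the schema.
        """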
4836        schema_obj = schema
4837        if not isinstance(schema_obj, Schema):
4838            schema_obj = Schema(schema_obj)
4839        return schema_obj.cpp_schema.get_num_rows()
4840
4841
4842class USPSDataset(SourceDataset):
4843    """
4844    A source dataset for reading and parsing the USPS dataset.
4845
4846    The generated dataset has two columns: :py:obj:`[image, label]`.
4847    The tensor of column :py:obj:`image` is of the uint8 type.
4848    The tensor of column :py:obj:`label` is a scalar of the uint32 type.
4849
4850    Args:
4851        dataset_dir (str): Path to the root directory that contains the dataset.
4852        usage (str, optional): Usage of this dataset, can be "train", "test" or "all". "train" will read from 7,291
4853            train samples, "test" will read from 2,007 test samples, "all" will read from all 9,298 samples.
4854            (default=None, will read all samples)
4855        num_samples (int, optional): The number of images to be included in the dataset
4856            (default=None, will read all images).
4857        num_parallel_workers (int, optional): Number of workers to read the data
4858            (default=None, will use value set in the config).
4859        shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
4860            (default=Shuffle.GLOBAL).
4861            If shuffle is False, no shuffling will be performed;
4862            If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL.
4863            Otherwise, there are two levels of shuffling:
4864
4865            - Shuffle.GLOBAL: Shuffle both the files and samples.
4866
4867            - Shuffle.FILES: Shuffle files only.
4868
4869        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
4870            When this argument is specified, `num_samples` reflects the maximum number of samples per shard.
4871        shard_id (int, optional): The shard ID within `num_shards` (default=None). This
4872            argument can only be specified when `num_shards` is also specified.
4873        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
4874            (default=None, which means no cache is used).
4875
4876    Raises:
4877        RuntimeError: If dataset_dir is not valid or does not exist or does not contain data files.
4878        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
4879        RuntimeError: If sampler and shuffle are specified at the same time.
4880        RuntimeError: If sampler and sharding are specified at the same time.
4881        RuntimeError: If num_shards is specified but shard_id is None.
4882        RuntimeError: If shard_id is specified but num_shards is None.
4883        ValueError: If usage is invalid.
4884        ValueError: If shard_id is invalid (< 0 or >= num_shards).
4885
4886    Examples:
4887        >>> usps_dataset_dir = "/path/to/usps_dataset_directory"
4888        >>>
4889        >>> # Read 3 samples from USPS dataset
4890        >>> dataset = ds.USPSDataset(dataset_dir=usps_dataset_dir, num_samples=3)
4891        >>>
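        >>> # A further sketch: read only the "train" split (usage values are described
        >>> # in the parameter list above)
        >>> dataset = ds.USPSDataset(dataset_dir=usps_dataset_dir, usage="train")
        >>>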
4892        >>> # Note: In USPS dataset, each dictionary has keys "image" and "label"
4893
4894    About USPS dataset:
4895
4896    USPS is a digit dataset automatically scanned from envelopes by the U.S. Postal Service
4897    containing a total of 9,298 16×16 pixel grayscale samples.
4898    The images are centered, normalized and show a broad range of font styles.
4899
4900    Here is the original USPS dataset structure.
4901    You can download and unzip the dataset files into this directory structure and read them with MindSpore's API.
4902
4903    .. code-block::

4904        .
4905        └── usps_dataset_dir
4906             ├── usps
4907             ├── usps.t
4908
4909    Citation:
4910
4911    .. code-block::
4912
4913        @article{hull1994database,
4914          title={A database for handwritten text recognition research},
4915          author={Hull, Jonathan J.},
4916          journal={IEEE Transactions on pattern analysis and machine intelligence},
4917          volume={16},
4918          number={5},
4919          pages={550--554},
4920          year={1994},
4921          publisher={IEEE}
4922        }
4923    """
4924
4925    @check_usps_dataset
4926    def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL,
4927                 num_shards=None, shard_id=None, cache=None):
4928        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
4929                         num_shards=num_shards, shard_id=shard_id, cache=cache)
4930
4931        self.dataset_dir = dataset_dir
4932        self.usage = replace_none(usage, "all")
4933
4934    def parse(self, children=None):
4935        return cde.USPSNode(self.dataset_dir, self.usage, self.num_samples, self.shuffle_flag, self.num_shards,
4936                            self.shard_id)
4937
4938
4939class VOCDataset(MappableDataset):
4940    """
4941    A source dataset for reading and parsing VOC dataset.
4942
4943    The generated dataset with different task setting has different output columns:
4944
4945    - task = :py:obj:`Detection`, output columns: :py:obj:`[image, dtype=uint8]`, :py:obj:`[bbox, dtype=float32]`, \
4946        :py:obj:`[label, dtype=uint32]`, :py:obj:`[difficult, dtype=uint32]`, :py:obj:`[truncate, dtype=uint32]`.
4947    - task = :py:obj:`Segmentation`, output columns: :py:obj:`[image, dtype=uint8]`, :py:obj:`[target,dtype=uint8]`.
4948
4949    Args:
4950        dataset_dir (str): Path to the root directory that contains the dataset.
4951        task (str, optional): Set the task type of reading voc data, now only support `Segmentation` or `Detection`
4952            (default=`Segmentation`).
4953        usage (str, optional): Set the usage of ImageSets (default=`train`). If task is `Segmentation`, image and
4954            annotation list will be loaded from ./ImageSets/Segmentation/usage + ".txt"; if task is `Detection`, image and
4955            annotation list will be loaded from ./ImageSets/Main/usage + ".txt"; if task and usage are not set, image and
4956            annotation list will be loaded from ./ImageSets/Segmentation/train.txt by default.
4957        class_indexing (dict, optional): A str-to-int mapping from label name to index, only valid in
4958            `Detection` task (default=None, the folder names will be sorted alphabetically and each
4959            class will be given a unique index starting from 0).
4960        num_samples (int, optional): The number of images to be included in the dataset
4961            (default=None, all images).
4962        num_parallel_workers (int, optional): Number of workers to read the data
4963            (default=None, number set in the config).
4964        shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
4965            order behavior shown in the table).
4966        decode (bool, optional): Decode the images after reading (default=False).
4967        sampler (Sampler, optional): Object used to choose samples from the dataset
4968            (default=None, expected order behavior shown in the table).
4969        num_shards (int, optional): Number of shards that the dataset will be divided
4970            into (default=None). When this argument is specified, `num_samples` reflects
4971            the maximum number of samples per shard.
4972        shard_id (int, optional): The shard ID within num_shards (default=None). This
4973            argument can only be specified when num_shards is also specified.
4974        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
4975            (default=None, which means no cache is used).
4976        extra_metadata (bool, optional): Flag to add extra meta-data to row. If True, an additional column named
4977            :py:obj:`[_meta-filename, dtype=string]` will be output at the end (default=False).
4978
4979    Raises:
4980        RuntimeError: If dataset_dir does not contain data files.
4981        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
4982        RuntimeError: If the XML of Annotations is in an invalid format.
4983        RuntimeError: If the XML of Annotations lacks the attribute `object`.
4984        RuntimeError: If the XML of Annotations lacks the attribute `bndbox`.
4985        RuntimeError: If sampler and shuffle are specified at the same time.
4986        RuntimeError: If sampler and sharding are specified at the same time.
4987        RuntimeError: If num_shards is specified but shard_id is None.
4988        RuntimeError: If shard_id is specified but num_shards is None.
4989        ValueError: If task is not 'Segmentation' or 'Detection'.
4990        ValueError: If task is 'Segmentation' but class_indexing is not None.
4991        ValueError: If the txt file related to the usage does not exist.
4992        ValueError: If shard_id is invalid (< 0 or >= num_shards).
4993
4994    Note:
4995        - Column '[_meta-filename, dtype=string]' won't be output unless an explicit rename dataset op
4996          is added to remove the prefix('_meta-').
4997        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
4998          The table below shows what input arguments are allowed and their expected behavior.
4999
5000    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
5001       :widths: 25 25 50
5002       :header-rows: 1
5003
5004       * - Parameter `sampler`
5005         - Parameter `shuffle`
5006         - Expected Order Behavior
5007       * - None
5008         - None
5009         - random order
5010       * - None
5011         - True
5012         - random order
5013       * - None
5014         - False
5015         - sequential order
5016       * - Sampler object
5017         - None
5018         - order defined by sampler
5019       * - Sampler object
5020         - True
5021         - not allowed
5022       * - Sampler object
5023         - False
5024         - not allowed
5025
5026    Examples:
5027        >>> voc_dataset_dir = "/path/to/voc_dataset_directory"
5028        >>>
5029        >>> # 1) Read VOC data for segmentation training
5030        >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir, task="Segmentation", usage="train")
5031        >>>
5032        >>> # 2) Read VOC data for detection training
5033        >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir, task="Detection", usage="train")
5034        >>>
5035        >>> # 3) Read all VOC dataset samples in voc_dataset_dir with 8 threads in random order
5036        >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir, task="Detection", usage="train",
5037        ...                         num_parallel_workers=8)
5038        >>>
5039        >>> # 4) Read then decode all VOC dataset samples in voc_dataset_dir in sequence
5040        >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir, task="Detection", usage="train",
5041        ...                         decode=True, shuffle=False)
5042        >>>
5043        >>> # In VOC dataset, if task='Segmentation', each dictionary has keys "image" and "target"
5044        >>> # In VOC dataset, if task='Detection', each dictionary has keys "image" and "annotation"
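        >>>
        >>> # 5) A sketch of a custom class-to-index mapping for the Detection task (the mapping
        >>> # below is illustrative; its keys must match the label names in the annotations)
        >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir, task="Detection", usage="train",
        ...                         class_indexing={"car": 0, "person": 1})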
5045
5046    About VOC dataset:
5047
5048    The PASCAL Visual Object Classes (VOC) challenge is a benchmark in visual
5049    object category recognition and detection, providing the vision and machine
5050    learning communities with a standard dataset of images and annotation, and
5051    standard evaluation procedures.
5052
5053    You can unzip the original VOC-2012 dataset files into this directory structure and read them with MindSpore's API.
5054
5055    .. code-block::
5056
5057        .
5058        └── voc2012_dataset_dir
5059            ├── Annotations
5060            │    ├── 2007_000027.xml
5061            │    ├── 2007_000032.xml
5062            │    ├── ...
5063            ├── ImageSets
5064            │    ├── Action
5065            │    ├── Layout
5066            │    ├── Main
5067            │    └── Segmentation
5068            ├── JPEGImages
5069            │    ├── 2007_000027.jpg
5070            │    ├── 2007_000032.jpg
5071            │    ├── ...
5072            ├── SegmentationClass
5073            │    ├── 2007_000032.png
5074            │    ├── 2007_000033.png
5075            │    ├── ...
5076            └── SegmentationObject
5077                 ├── 2007_000032.png
5078                 ├── 2007_000033.png
5079                 ├── ...
5080
5081    Citation:
5082
5083    .. code-block::
5084
5085        @article{Everingham10,
5086        author       = {Everingham, M. and Van~Gool, L. and Williams, C. K. I. and Winn, J. and Zisserman, A.},
5087        title        = {The Pascal Visual Object Classes (VOC) Challenge},
5088        journal      = {International Journal of Computer Vision},
5089        volume       = {88},
5090        year         = {2012},
5091        number       = {2},
5092        month        = {jun},
5093        pages        = {303--338},
5094        biburl       = {http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham10.html#bibtex},
5095        howpublished = {http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html}
5096        }
5097    """
5098
5099    @check_vocdataset
5100    def __init__(self, dataset_dir, task="Segmentation", usage="train", class_indexing=None, num_samples=None,
5101                 num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None,
5102                 cache=None, extra_metadata=False):
5103        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
5104                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
5105        self.dataset_dir = dataset_dir
5106        self.task = replace_none(task, "Segmentation")
5107        self.usage = replace_none(usage, "train")
5108        self.class_indexing = replace_none(class_indexing, {})
5109        self.decode = replace_none(decode, False)
5110        self.extra_metadata = extra_metadata
5111
5112    def parse(self, children=None):
5113        return cde.VOCNode(self.dataset_dir, self.task, self.usage, self.class_indexing, self.decode, self.sampler,
5114                           self.extra_metadata)
5115
5116    def get_class_indexing(self):
5117        """
5118        Get the class index.
5119
5120        Returns:
5121            dict, a str-to-int mapping from label name to index.
5122
5123        Examples:
5124            >>> voc_dataset_dir = "/path/to/voc_dataset_directory"
5125            >>>
5126            >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir)
5127            >>> class_indexing = dataset.get_class_indexing()
5128        """
5129        if self.task != "Detection":
5130            raise NotImplementedError("Only 'Detection' supports get_class_indexing.")
5131        if self.class_indexing is None or not self.class_indexing:
5132            if self._class_indexing is None:
5133                runtime_getter = self._init_tree_getters()
5134                self._class_indexing = runtime_getter[0].GetClassIndexing()
5135            self.class_indexing = {}
5136            for pair in self._class_indexing:
5137                self.class_indexing[pair[0]] = pair[1][0]
5138        return self.class_indexing
5139
5140
5141class CocoDataset(MappableDataset):
5142    """
5143    A source dataset for reading and parsing COCO dataset.
5144
5145    CocoDataset supports four kinds of tasks, which are Object Detection, Keypoint Detection, Stuff Segmentation and
5146    Panoptic Segmentation of 2017 Train/Val/Test dataset.
5147
5148    The generated dataset with different task setting has different output columns:
5149
5150    - task = :py:obj:`Detection`, output columns: :py:obj:`[image, dtype=uint8]`, :py:obj:`[bbox, dtype=float32]`, \
5151        :py:obj:`[category_id, dtype=uint32]`, :py:obj:`[iscrowd, dtype=uint32]`.
5152    - task = :py:obj:`Stuff`, output columns: :py:obj:`[image, dtype=uint8]`, :py:obj:`[segmentation,dtype=float32]`, \
5153        :py:obj:`[iscrowd,dtype=uint32]`.
5154    - task = :py:obj:`Keypoint`, output columns: :py:obj:`[image, dtype=uint8]`, \
5155        :py:obj:`[keypoints, dtype=float32]`, :py:obj:`[num_keypoints, dtype=uint32]`.
5156    - task = :py:obj:`Panoptic`, output columns: :py:obj:`[image, dtype=uint8]`, :py:obj:`[bbox, dtype=float32]`, \
5157        :py:obj:`[category_id, dtype=uint32]`, :py:obj:`[iscrowd, dtype=uint32]`, :py:obj:`[area, dtype=uint32]`.
5158
5159    Args:
5160        dataset_dir (str): Path to the root directory that contains the dataset.
5161        annotation_file (str): Path to the annotation JSON file.
5162        task (str, optional): Set the task type for reading COCO data. Supported task types:
5163            `Detection`, `Stuff`, `Panoptic` and `Keypoint` (default=`Detection`).
5164        num_samples (int, optional): The number of images to be included in the dataset
5165            (default=None, all images).
5166        num_parallel_workers (int, optional): Number of workers to read the data
5167            (default=None, number set in the configuration file).
5168        shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
5169            order behavior shown in the table).
5170        decode (bool, optional): Decode the images after reading (default=False).
5171        sampler (Sampler, optional): Object used to choose samples from the dataset
5172            (default=None, expected order behavior shown in the table).
5173        num_shards (int, optional): Number of shards that the dataset will be divided
5174            into (default=None). When this argument is specified, `num_samples` reflects
5175            the maximum number of samples per shard.
5176        shard_id (int, optional): The shard ID within num_shards (default=None). This
5177            argument can only be specified when num_shards is also specified.
5178        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
5179            (default=None, which means no cache is used).
5180        extra_metadata (bool, optional): Flag to add extra meta-data to row. If True, an additional column named
5181            :py:obj:`[_meta-filename, dtype=string]` will be output at the end (default=False).
5182
5183    Raises:
5184        RuntimeError: If dataset_dir does not contain data files.
5185        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
5186        RuntimeError: If sampler and shuffle are specified at the same time.
5187        RuntimeError: If sampler and sharding are specified at the same time.
5188        RuntimeError: If num_shards is specified but shard_id is None.
5189        RuntimeError: If shard_id is specified but num_shards is None.
5190        RuntimeError: If parsing the JSON file failed.
5191        ValueError: If task is not in [`Detection`, `Stuff`, `Panoptic`, `Keypoint`].
5192        ValueError: If annotation_file does not exist.
5193        ValueError: If dataset_dir does not exist.
5194        ValueError: If shard_id is invalid (< 0 or >= num_shards).
5195
5196    Note:
5197        - Column '[_meta-filename, dtype=string]' won't be output unless an explicit rename dataset op is added
5198          to remove the prefix('_meta-').
5199        - CocoDataset doesn't support PKSampler.
5200        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
5201          The table below shows what input arguments are allowed and their expected behavior.
5202
5203    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
5204       :widths: 25 25 50
5205       :header-rows: 1
5206
5207       * - Parameter `sampler`
5208         - Parameter `shuffle`
5209         - Expected Order Behavior
5210       * - None
5211         - None
5212         - random order
5213       * - None
5214         - True
5215         - random order
5216       * - None
5217         - False
5218         - sequential order
5219       * - Sampler object
5220         - None
5221         - order defined by sampler
5222       * - Sampler object
5223         - True
5224         - not allowed
5225       * - Sampler object
5226         - False
5227         - not allowed
5228
5229    Examples:
5230        >>> coco_dataset_dir = "/path/to/coco_dataset_directory/images"
5231        >>> coco_annotation_file = "/path/to/coco_dataset_directory/annotation_file"
5232        >>>
5233        >>> # 1) Read COCO data for Detection task
5234        >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir,
5235        ...                          annotation_file=coco_annotation_file,
5236        ...                          task='Detection')
5237        >>>
5238        >>> # 2) Read COCO data for Stuff task
5239        >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir,
5240        ...                          annotation_file=coco_annotation_file,
5241        ...                          task='Stuff')
5242        >>>
5243        >>> # 3) Read COCO data for Panoptic task
5244        >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir,
5245        ...                          annotation_file=coco_annotation_file,
5246        ...                          task='Panoptic')
5247        >>>
5248        >>> # 4) Read COCO data for Keypoint task
5249        >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir,
5250        ...                          annotation_file=coco_annotation_file,
5251        ...                          task='Keypoint')
5252        >>>
5253        >>> # In COCO dataset, each dictionary has keys "image" and "annotation"
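        >>>
        >>> # 5) A sketch of exposing the file name column: with extra_metadata=True the extra
        >>> # column is "_meta-filename"; renaming it drops the "_meta-" prefix (see the Note above)
        >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir,
        ...                          annotation_file=coco_annotation_file,
        ...                          task='Detection',
        ...                          extra_metadata=True)
        >>> dataset = dataset.rename("_meta-filename", "filename")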
5254
5255    About COCO dataset:
5256
5257    COCO (Microsoft Common Objects in Context) is a large-scale object detection, segmentation, and captioning dataset
5258    with several features: Object segmentation, Recognition in context, Superpixel stuff segmentation,
5259    330K images (>200K labeled), 1.5 million object instances, 80 object categories, 91 stuff categories,
5260    5 captions per image, 250,000 people with keypoints. In contrast to the popular ImageNet dataset, COCO has fewer
5261    categories but more instances per category.
5262
5263    You can unzip the original COCO-2017 dataset files into this directory structure and read them with MindSpore's API.
5264
5265    .. code-block::
5266
5267        .
5268        └── coco_dataset_directory
5269             ├── train2017
5270             │    ├── 000000000009.jpg
5271             │    ├── 000000000025.jpg
5272             │    ├── ...
5273             ├── test2017
5274             │    ├── 000000000001.jpg
5275             │    ├── 000000058136.jpg
5276             │    ├── ...
5277             ├── val2017
5278             │    ├── 000000000139.jpg
5279             │    ├── 000000057027.jpg
5280             │    ├── ...
5281             └── annotations
5282                  ├── captions_train2017.json
5283                  ├── captions_val2017.json
5284                  ├── instances_train2017.json
5285                  ├── instances_val2017.json
5286                  ├── person_keypoints_train2017.json
5287                  └── person_keypoints_val2017.json
5288
5289    Citation:
5290
5291    .. code-block::
5292
5293        @article{DBLP:journals/corr/LinMBHPRDZ14,
5294        author        = {Tsung{-}Yi Lin and Michael Maire and Serge J. Belongie and
5295                        Lubomir D. Bourdev and  Ross B. Girshick and James Hays and
5296                        Pietro Perona and Deva Ramanan and Piotr Doll{\'{a}}r and C. Lawrence Zitnick},
5297        title         = {Microsoft {COCO:} Common Objects in Context},
5298        journal       = {CoRR},
5299        volume        = {abs/1405.0312},
5300        year          = {2014},
5301        url           = {http://arxiv.org/abs/1405.0312},
5302        archivePrefix = {arXiv},
5303        eprint        = {1405.0312},
5304        timestamp     = {Mon, 13 Aug 2018 16:48:13 +0200},
5305        biburl        = {https://dblp.org/rec/journals/corr/LinMBHPRDZ14.bib},
5306        bibsource     = {dblp computer science bibliography, https://dblp.org}
5307        }
5308    """
5309
5310    @check_cocodataset
5311    def __init__(self, dataset_dir, annotation_file, task="Detection", num_samples=None, num_parallel_workers=None,
5312                 shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None,
5313                 extra_metadata=False):
5314        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
5315                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
5316        self.dataset_dir = dataset_dir
5317        self.annotation_file = annotation_file
5318        self.task = replace_none(task, "Detection")
5319        self.decode = replace_none(decode, False)
5320        self.extra_metadata = extra_metadata
5321
5322    def parse(self, children=None):
5323        return cde.CocoNode(self.dataset_dir, self.annotation_file, self.task, self.decode, self.sampler,
5324                            self.extra_metadata)
5325
5326    def get_class_indexing(self):
5327        """
5328        Get the class index.
5329
5330        Returns:
5331            dict, a str-to-list<int> mapping from label name to index.
5332
5333        Examples:
5334            >>> coco_dataset_dir = "/path/to/coco_dataset_directory/images"
5335            >>> coco_annotation_file = "/path/to/coco_dataset_directory/annotation_file"
5336            >>>
5337            >>> # Read COCO data for Detection task
5338            >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir,
5339            ...                          annotation_file=coco_annotation_file,
5340            ...                          task='Detection')
5341            >>>
5342            >>> class_indexing = dataset.get_class_indexing()
5343        """
5344        if self.task not in {"Detection", "Panoptic"}:
5345            raise NotImplementedError("Only 'Detection' and 'Panoptic' support get_class_indexing.")
5346        if self._class_indexing is None:
5347            runtime_getter = self._init_tree_getters()
5348            self._class_indexing = dict(runtime_getter[0].GetClassIndexing())
5349        return self._class_indexing
5350
5351
5352class CelebADataset(MappableDataset):
5353    """
5354    A source dataset for reading and parsing CelebA dataset.
5355    Currently it only supports reading `list_attr_celeba.txt`, which contains the attribute annotations of the dataset.
5356
5357    The generated dataset has two columns: :py:obj:`[image, attr]`.
5358    The tensor of column :py:obj:`image` is of the uint8 type.
5359    The tensor of column :py:obj:`attr` is of the uint32 type and one-hot encoded.
5360
5361    Args:
5362        dataset_dir (str): Path to the root directory that contains the dataset.
5363        num_parallel_workers (int, optional): Number of workers to read the data (default=None, will use value set in
5364            the config).
5365        shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None).
5366        usage (str, optional): Specify the `train`, `valid`, `test` part or `all` parts of dataset
5367            (default=`all`, will read all samples).
5368        sampler (Sampler, optional): Object used to choose samples from the dataset (default=None).
5369        decode (bool, optional): Decode the images after reading (default=False).
5370        extensions (list[str], optional): List of file extensions to be included in the dataset (default=None).
5371        num_samples (int, optional): The number of images to be included in the dataset
5372            (default=None, will include all images).
5373        num_shards (int, optional): Number of shards that the dataset will be divided
5374            into (default=None). When this argument is specified, `num_samples` reflects
5375            the maximum number of samples per shard.
5376        shard_id (int, optional): The shard ID within `num_shards` (default=None). This
5377            argument can only be specified when `num_shards` is also specified.
5378        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
5379            (default=None, which means no cache is used).
5380
5381    Raises:
5382        RuntimeError: If dataset_dir does not contain data files.
5383        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
5384        RuntimeError: If sampler and shuffle are specified at the same time.
5385        RuntimeError: If sampler and sharding are specified at the same time.
5386        RuntimeError: If num_shards is specified but shard_id is None.
5387        RuntimeError: If shard_id is specified but num_shards is None.
5388        ValueError: If shard_id is invalid (< 0 or >= num_shards).
5389
5390    Note:
5391        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
5392          The table below shows what input arguments are allowed and their expected behavior.
5393
5394    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
5395       :widths: 25 25 50
5396       :header-rows: 1
5397
5398       * - Parameter `sampler`
5399         - Parameter `shuffle`
5400         - Expected Order Behavior
5401       * - None
5402         - None
5403         - random order
5404       * - None
5405         - True
5406         - random order
5407       * - None
5408         - False
5409         - sequential order
5410       * - Sampler object
5411         - None
5412         - order defined by sampler
5413       * - Sampler object
5414         - True
5415         - not allowed
5416       * - Sampler object
5417         - False
5418         - not allowed
5419
5420    Examples:
5421        >>> celeba_dataset_dir = "/path/to/celeba_dataset_directory"
5422        >>>
5423        >>> # Read 5 samples from CelebA dataset
5424        >>> dataset = ds.CelebADataset(dataset_dir=celeba_dataset_dir, usage='train', num_samples=5)
5425        >>>
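        >>> # A further sketch: read the validation split and decode images while reading
        >>> dataset = ds.CelebADataset(dataset_dir=celeba_dataset_dir, usage='valid', decode=True)
        >>>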
5426        >>> # Note: In celeba dataset, each data dictionary owns keys "image" and "attr"
5427
5428    About CelebA dataset:
5429
5430    CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset
5431    with more than 200K celebrity images, each with 40 attribute annotations.
5432
5433    The images in this dataset cover large pose variations and background clutter.
5434    CelebA has large diversities, large quantities, and rich annotations, including
5435
5436    * 10,177 identities,
5437    * 202,599 face images, and
5438    * 5 landmark locations and 40 binary attribute annotations per image.
5439
5440    The dataset can be employed as the training and test sets for the following computer
5441    vision tasks: face attribute recognition, face detection, landmark (or facial part)
5442    localization, and face editing & synthesis.
5443
5444    Original CelebA dataset structure:
5445
5446    .. code-block::
5447
5448        .
5449        └── CelebA
5450             ├── README.md
5451             ├── Img
5452             │    ├── img_celeba.7z
5453             │    ├── img_align_celeba_png.7z
5454             │    └── img_align_celeba.zip
5455             ├── Eval
5456             │    └── list_eval_partition.txt
5457             └── Anno
5458                  ├── list_landmarks_celeba.txt
5459                  ├── list_landmarks_align_celeba.txt
5460                  ├── list_bbox_celeba.txt
5461                  ├── list_attr_celeba.txt
5462                  └── identity_CelebA.txt
5463
5464    You can unzip the dataset files into the following structure and read them with MindSpore's API.
5465
5466    .. code-block::
5467
5468        .
5469        └── celeba_dataset_directory
5470            ├── list_attr_celeba.txt
5471            ├── 000001.jpg
5472            ├── 000002.jpg
5473            ├── 000003.jpg
5474            ├── ...
5475
5476    Citation:
5477
5478    .. code-block::
5479
5480        @article{DBLP:journals/corr/LiuLWT14,
5481        author        = {Ziwei Liu and Ping Luo and Xiaogang Wang and Xiaoou Tang},
5482        title         = {Deep Learning Face Attributes in the Wild},
5483        journal       = {CoRR},
5484        volume        = {abs/1411.7766},
5485        year          = {2014},
5486        url           = {http://arxiv.org/abs/1411.7766},
5487        archivePrefix = {arXiv},
5488        eprint        = {1411.7766},
5489        timestamp     = {Tue, 10 Dec 2019 15:37:26 +0100},
5490        biburl        = {https://dblp.org/rec/journals/corr/LiuLWT14.bib},
5491        bibsource     = {dblp computer science bibliography, https://dblp.org},
5492        howpublished  = {http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html}
5493        }
5494    """
5495
5496    @check_celebadataset
5497    def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False,
5498                 extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None):
5499        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
5500                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
5501        self.dataset_dir = dataset_dir
5502        self.decode = replace_none(decode, False)
5503        self.extensions = replace_none(extensions, [])
5504        self.usage = replace_none(usage, "all")
5505
5506    def parse(self, children=None):
5507        if self.usage != "all":
5508            dataset_dir = os.path.realpath(self.dataset_dir)
5509            partition_file = os.path.join(dataset_dir, "list_eval_partition.txt")
5510            if not os.path.exists(partition_file):
5511                raise RuntimeError("Partition file can not be found when usage is not 'all'.")
5512        return cde.CelebANode(self.dataset_dir, self.usage, self.sampler, self.decode, self.extensions)
5513
5514
5515class CLUEDataset(SourceDataset):
5516    """
5517    A source dataset that reads and parses CLUE datasets.
5518    Supported CLUE classification tasks: `AFQMC`, `TNEWS`, `IFLYTEK`, `CMNLI`, `WSC` and `CSL`.
5519
5520    The generated dataset with different task setting has different output columns:
5521
5522    - task = :py:obj:`AFQMC`
5523        - usage = :py:obj:`train`, output columns: :py:obj:`[sentence1, dtype=string]`, \
5524            :py:obj:`[sentence2, dtype=string]`, :py:obj:`[label, dtype=string]`.
5525        - usage = :py:obj:`test`, output columns: :py:obj:`[id, dtype=uint8]`, \
5526            :py:obj:`[sentence1, dtype=string]`, :py:obj:`[sentence2, dtype=string]`.
5527        - usage = :py:obj:`eval`, output columns: :py:obj:`[sentence1, dtype=string]`, \
5528            :py:obj:`[sentence2, dtype=string]`, :py:obj:`[label, dtype=string]`.
5529
5530    - task = :py:obj:`TNEWS`
5531        - usage = :py:obj:`train`, output columns: :py:obj:`[label, dtype=string]`, \
5532            :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`, :py:obj:`[keywords, dtype=string]`.
5533        - usage = :py:obj:`test`, output columns: :py:obj:`[label, dtype=string]`, \
5534            :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`, :py:obj:`[keywords, dtype=string]`.
5535        - usage = :py:obj:`eval`, output columns: :py:obj:`[label, dtype=string]`, \
5536            :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`, :py:obj:`[keywords, dtype=string]`.
5537
5538    - task = :py:obj:`IFLYTEK`
5539        - usage = :py:obj:`train`, output columns: :py:obj:`[label, dtype=string]`, \
5540            :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`.
5541        - usage = :py:obj:`test`, output columns: :py:obj:`[id, dtype=string]`, \
5542            :py:obj:`[sentence, dtype=string]`.
5543        - usage = :py:obj:`eval`, output columns: :py:obj:`[label, dtype=string]`, \
5544            :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`.
5545
5546    - task = :py:obj:`CMNLI`
5547        - usage = :py:obj:`train`, output columns: :py:obj:`[sentence1, dtype=string]`, \
5548            :py:obj:`[sentence2, dtype=string]`, :py:obj:`[label, dtype=string]`.
5549        - usage = :py:obj:`test`, output columns: :py:obj:`[id, dtype=uint8]`, \
5550            :py:obj:`[sentence1, dtype=string]`, :py:obj:`[sentence2, dtype=string]`.
5551        - usage = :py:obj:`eval`, output columns: :py:obj:`[sentence1, dtype=string]`, \
5552            :py:obj:`[sentence2, dtype=string]`, :py:obj:`[label, dtype=string]`.
5553
5554    - task = :py:obj:`WSC`
5555        - usage = :py:obj:`train`, output columns: :py:obj:`[span1_index, dtype=uint8]`, \
5556            :py:obj:`[span2_index, dtype=uint8]`, :py:obj:`[span1_text, dtype=string]`, \
5557            :py:obj:`[span2_text, dtype=string]`, :py:obj:`[idx, dtype=uint8]`, \
5558            :py:obj:`[text, dtype=string]`, :py:obj:`[label, dtype=string]`.
5559        - usage = :py:obj:`test`, output columns: :py:obj:`[span1_index, dtype=uint8]`, \
5560            :py:obj:`[span2_index, dtype=uint8]`, :py:obj:`[span1_text, dtype=string]`, \
5561            :py:obj:`[span2_text, dtype=string]`, :py:obj:`[idx, dtype=uint8]`, :py:obj:`[text, dtype=string]`.
5562        - usage = :py:obj:`eval`, output columns: :py:obj:`[span1_index, dtype=uint8]`, \
5563            :py:obj:`[span2_index, dtype=uint8]`, :py:obj:`[span1_text, dtype=string]`, \
5564            :py:obj:`[span2_text, dtype=string]`, :py:obj:`[idx, dtype=uint8]`, \
5565            :py:obj:`[text, dtype=string]`, :py:obj:`[label, dtype=string]`.
5566
5567    - task = :py:obj:`CSL`
5568        - usage = :py:obj:`train`, output columns: :py:obj:`[id, dtype=uint8]`, \
5569            :py:obj:`[abst, dtype=string]`, :py:obj:`[keyword, dtype=string]`, :py:obj:`[label, dtype=string]`.
5570        - usage = :py:obj:`test`, output columns: :py:obj:`[id, dtype=uint8]`, \
5571            :py:obj:`[abst, dtype=string]`, :py:obj:`[keyword, dtype=string]`.
5572        - usage = :py:obj:`eval`, output columns: :py:obj:`[id, dtype=uint8]`, \
5573            :py:obj:`[abst, dtype=string]`, :py:obj:`[keyword, dtype=string]`, :py:obj:`[label, dtype=string]`.
5574
5575    Args:
5576        dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for
5577            a pattern of files. The list will be sorted in a lexicographical order.
5578        task (str, optional): The kind of task, one of `AFQMC`, `TNEWS`, `IFLYTEK`, `CMNLI`, `WSC` and `CSL`.
5579            (default=AFQMC).
5580        usage (str, optional): Specify the `train`, `test` or `eval` part of dataset (default="train").
5581        num_samples (int, optional): The number of samples to be included in the dataset
5582            (default=None, will include all samples).
5583        num_parallel_workers (int, optional): Number of workers to read the data
5584            (default=None, number set in the config).
5585        shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
5586            (default=Shuffle.GLOBAL).
5587            If shuffle is False, no shuffling will be performed;
5588            If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL.
5589            Otherwise, there are two levels of shuffling:
5590
5591            - Shuffle.GLOBAL: Shuffle both the files and samples.
5592
5593            - Shuffle.FILES: Shuffle files only.
5594
5595        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
5596            When this argument is specified, `num_samples` reflects the maximum number of samples per shard.
5597        shard_id (int, optional): The shard ID within num_shards (default=None). This
5598            argument can only be specified when num_shards is also specified.
5599        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
5600            (default=None, which means no cache is used).
5601
5602    Raises:
5603        RuntimeError: If dataset_files are not valid or do not exist.
5604        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
5605        RuntimeError: If num_shards is specified but shard_id is None.
5606        RuntimeError: If shard_id is specified but num_shards is None.
5607
5608    Examples:
5609        >>> clue_dataset_dir = ["/path/to/clue_dataset_file"] # contains 1 or multiple clue files
5610        >>> dataset = ds.CLUEDataset(dataset_files=clue_dataset_dir, task='AFQMC', usage='train')
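        >>>
        >>> # A further sketch: read the evaluation split of the TNEWS task
        >>> # (dataset_files should point to the corresponding TNEWS json files)
        >>> dataset = ds.CLUEDataset(dataset_files=clue_dataset_dir, task='TNEWS', usage='eval')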
5611
5612    About CLUE dataset:
5613
5614    CLUE is a Chinese Language Understanding Evaluation benchmark. It contains multiple
5615    tasks, including single-sentence classification, sentence pair classification, and machine
5616    reading comprehension.
5617
5618    You can unzip the dataset files into the following structure and read them with MindSpore's API.
5619    Take the AFQMC dataset as an example:
5620
5621    .. code-block::
5622
5623        .
5624        └── afqmc_public
5625             ├── train.json
5626             ├── test.json
5627             └── dev.json
5628
5629    Citation:
5630
5631    .. code-block::
5632
5633        @article{CLUEbenchmark,
5634        title   = {CLUE: A Chinese Language Understanding Evaluation Benchmark},
5635        author  = {Liang Xu, Xuanwei Zhang, Lu Li, Hai Hu, Chenjie Cao, Weitang Liu, Junyi Li, Yudong Li,
5636                Kai Sun, Yechen Xu, Yiming Cui, Cong Yu, Qianqian Dong, Yin Tian, Dian Yu, Bo Shi, Jun Zeng,
5637                Rongzhao Wang, Weijian Xie, Yanting Li, Yina Patterson, Zuoyu Tian, Yiwen Zhang, He Zhou,
5638                Shaoweihua Liu, Qipeng Zhao, Cong Yue, Xinrui Zhang, Zhengliang Yang, Zhenzhong Lan},
5639        journal = {arXiv preprint arXiv:2004.05986},
5640        year    = {2020},
5641        howpublished = {https://github.com/CLUEbenchmark/CLUE}
5642        }
5643    """
5644
5645    @check_cluedataset
5646    def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None, num_parallel_workers=None,
5647                 shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
5648        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
5649                         num_shards=num_shards, shard_id=shard_id, cache=cache)
5650        self.dataset_files = self._find_files(dataset_files)
5651        self.usage = replace_none(usage, 'train')
5652        self.task = replace_none(task, 'AFQMC')
5653
5654    def parse(self, children=None):
5655        return cde.CLUENode(self.dataset_files, self.task, self.usage, self.num_samples, self.shuffle_flag,
5656                            self.num_shards, self.shard_id)
5657
5658
5659class CSVDataset(SourceDataset):
5660    """
5661    A source dataset that reads and parses comma-separated values (CSV) datasets.
5662    The columns of generated dataset depend on the source CSV files.
5663
5664    Args:
5665        dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search
5666            for a pattern of files. The list will be sorted in a lexicographical order.
5667        field_delim (str, optional): A string that indicates the char delimiter to separate fields (default=',').
5668        column_defaults (list, optional): List of default values for the CSV field (default=None). Each item
5669            in the list is either a valid type (float, int, or string). If this is not provided, all
5670            columns will be treated as string type.
5671        column_names (list[str], optional): List of column names of the dataset (default=None). If this
5672            is not provided, the column names will be inferred from the first row of the CSV file.
5673        num_samples (int, optional): The number of samples to be included in the dataset
5674            (default=None, will include all samples).
5675        num_parallel_workers (int, optional): Number of workers to read the data
5676            (default=None, number set in the config).
5677        shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
5678            (default=Shuffle.GLOBAL).
5679            If shuffle is False, no shuffling will be performed;
5680            If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL.
5681            Otherwise, there are two levels of shuffling:
5682
5683            - Shuffle.GLOBAL: Shuffle both the files and samples.
5684
5685            - Shuffle.FILES: Shuffle files only.
5686
5687        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
5688            When this argument is specified, `num_samples` reflects the maximum number of samples per shard.
5689        shard_id (int, optional): The shard ID within num_shards (default=None). This
5690            argument can only be specified when num_shards is also specified.
5691        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
5692            (default=None, which means no cache is used).
5693
5694    Raises:
5695        RuntimeError: If dataset_files are not valid or do not exist.
5696        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
5697        RuntimeError: If num_shards is specified but shard_id is None.
5698        RuntimeError: If shard_id is specified but num_shards is None.
5699
5700    Examples:
5701        >>> csv_dataset_dir = ["/path/to/csv_dataset_file"] # contains 1 or multiple csv files
5702        >>> dataset = ds.CSVDataset(dataset_files=csv_dataset_dir, column_names=['col1', 'col2', 'col3', 'col4'])
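        >>>
        >>> # A further sketch: semicolon-separated files can be read by overriding field_delim
        >>> # (the column names here are illustrative)
        >>> dataset = ds.CSVDataset(dataset_files=csv_dataset_dir, field_delim=';',
        ...                         column_names=['col1', 'col2'])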
5703    """
5704
5705    @check_csvdataset
5706    def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
5707                 num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
5708        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
5709                         num_shards=num_shards, shard_id=shard_id, cache=cache)
5710        self.dataset_files = self._find_files(dataset_files)
5711        self.dataset_files.sort()
5712        self.field_delim = replace_none(field_delim, ',')
5713        self.column_defaults = replace_none(column_defaults, [])
5714        self.column_names = replace_none(column_names, [])
5715
5716    def parse(self, children=None):
5717        return cde.CSVNode(self.dataset_files, self.field_delim, self.column_defaults, self.column_names,
5718                           self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id)
5719
5720
5721class SBUDataset(MappableDataset):
5722    """
5723    A source dataset for reading and parsing the SBU dataset.
5724
5725    The generated dataset has two columns :py:obj:`[image, caption]`.
5726    The tensor of column :py:obj:`image` is of the uint8 type.
5727    The tensor of column :py:obj:`caption` is of the string type.
5728
5729    Args:
5730        dataset_dir (str): Path to the root directory that contains the dataset.
5731        decode (bool, optional): Decode the images after reading (default=False).
5732        num_samples (int, optional): The number of images to be included in the dataset
5733            (default=None, will read all images).
5734        num_parallel_workers (int, optional): Number of workers to read the data
5735            (default=None, will use value set in the config).
5736        shuffle (bool, optional): Whether or not to perform shuffle on the dataset
5737            (default=None, expected order behavior shown in the table).
5738        sampler (Sampler, optional): Object used to choose samples from the
5739            dataset (default=None, expected order behavior shown in the table).
5740        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
5741            When this argument is specified, `num_samples` reflects the maximum number of samples per shard.
5742        shard_id (int, optional): The shard ID within `num_shards` (default=None). This
5743            argument can only be specified when `num_shards` is also specified.
5744        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
5745            (default=None, which means no cache is used).
5746
5747    Raises:
5748        RuntimeError: If dataset_dir does not contain data files.
5749        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
5750        RuntimeError: If sampler and shuffle are specified at the same time.
5751        RuntimeError: If sampler and sharding are specified at the same time.
5752        RuntimeError: If num_shards is specified but shard_id is None.
5753        RuntimeError: If shard_id is specified but num_shards is None.
5754        ValueError: If shard_id is invalid (< 0 or >= num_shards).
5755
5756    Note:
5757        - This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive.
5758          The table below shows what input arguments are allowed and their expected behavior.
5759
5760    .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
5761       :widths: 25 25 50
5762       :header-rows: 1
5763
5764       * - Parameter 'sampler'
5765         - Parameter 'shuffle'
5766         - Expected Order Behavior
5767       * - None
5768         - None
5769         - random order
5770       * - None
5771         - True
5772         - random order
5773       * - None
5774         - False
5775         - sequential order
5776       * - Sampler object
5777         - None
5778         - order defined by sampler
5779       * - Sampler object
5780         - True
5781         - not allowed
5782       * - Sampler object
5783         - False
5784         - not allowed
5785
5786    Examples:
5787        >>> sbu_dataset_dir = "/path/to/sbu_dataset_directory"
5788        >>> # Read 3 samples from SBU dataset
5789        >>> dataset = ds.SBUDataset(dataset_dir=sbu_dataset_dir, num_samples=3)
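        >>>
        >>> # A further sketch: decode the images while reading
        >>> dataset = ds.SBUDataset(dataset_dir=sbu_dataset_dir, decode=True)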
5790
5791    About SBU dataset:
5792
5793    SBU dataset is a large captioned photo collection.
5794    It contains one million images with associated visually relevant captions.
5795
5796    You should manually download the images using official download.m by replacing 'urls{i}(24, end)' with
5797    'urls{i}(24:1:end)' and keep the directory as below.
5798
5799    .. code-block::
5800
5801        .
5802        └─ dataset_dir
5803           ├── SBU_captioned_photo_dataset_captions.txt
5804           ├── SBU_captioned_photo_dataset_urls.txt
5805           └── sbu_images
5806               ├── m_3326_3596303505_3ce4c20529.jpg
5807               ├── ......
5808               └── m_2522_4182181099_c3c23ab1cc.jpg
5809
5810    Citation:
5811
5812    .. code-block::
5813
5814        @inproceedings{Ordonez:2011:im2text,
5815          Author    = {Vicente Ordonez and Girish Kulkarni and Tamara L. Berg},
5816          Title     = {Im2Text: Describing Images Using 1 Million Captioned Photographs},
5817          Booktitle = {Neural Information Processing Systems ({NIPS})},
5818          Year      = {2011},
5819        }
5820    """
5821
5822    @check_sbu_dataset
5823    def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False,
5824                 sampler=None, num_shards=None, shard_id=None, cache=None):
5825        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
5826                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
5827
5828        self.dataset_dir = dataset_dir
5829        self.decode = replace_none(decode, False)
5830
5831    def parse(self, children=None):
5832        return cde.SBUNode(self.dataset_dir, self.decode, self.sampler)
5833
5834
5835class _Flowers102Dataset:
5836    """
5837    Mainly for loading the Flowers102 dataset; returns one row at a time.
5838    """
5839    def __init__(self, dataset_dir, task, usage, decode):
5840        self.dataset_dir = os.path.realpath(dataset_dir)
5841        self.task = task
5842        self.usage = usage
5843        self.decode = decode
5844
5845        if self.task == "Classification":
5846            self.column_names = ["image", "label"]
5847        else:
5848            self.column_names = ["image", "segmentation", "label"]
5849
5850        labels_path = os.path.join(self.dataset_dir, "imagelabels.mat")
5851        setid_path = os.path.join(self.dataset_dir, "setid.mat")
5852        # minus one to transform 1~102 to 0 ~ 101
5853        self.labels = (loadmat(labels_path)["labels"][0] - 1).astype(np.uint32)
5854        self.setid = loadmat(setid_path)
5855
5856        if self.usage == 'train':
5857            self.indices = self.setid["trnid"][0].tolist()
5858        elif self.usage == 'test':
5859            self.indices = self.setid["tstid"][0].tolist()
5860        elif self.usage == 'valid':
5861            self.indices = self.setid["valid"][0].tolist()
5862        elif self.usage == 'all':
5863            self.indices = self.setid["trnid"][0].tolist()
5864            self.indices += self.setid["tstid"][0].tolist()
5865            self.indices += self.setid["valid"][0].tolist()
5866        else:
5867            raise ValueError("Input usage is not within the valid set of ['train', 'valid', 'test', 'all'].")
5868
5869    def __getitem__(self, index):
5870        # range: 1 ~ 8189
5871        image_path = os.path.join(self.dataset_dir, "jpg", "image_" + str(self.indices[index]).zfill(5) + ".jpg")
5872        if not os.path.exists(image_path):
5873            raise RuntimeError("Can not find image file: " + image_path)
5874
5875        if self.decode:
5876            image = np.asarray(Image.open(image_path).convert("RGB"))
5877        else:
5878            image = np.fromfile(image_path, dtype=np.uint8)
5879
5880        label = self.labels[self.indices[index] - 1]
5881
5882        if self.task == "Segmentation":
5883            segmentation_path = \
5884                os.path.join(self.dataset_dir, "segmim", "segmim_" + str(self.indices[index]).zfill(5) + ".jpg")
5885            if not os.path.exists(segmentation_path):
5886                raise RuntimeError("Can not find segmentation file: " + segmentation_path)
5887            if self.decode:
5888                segmentation = np.asarray(Image.open(segmentation_path).convert("RGB"))
5889            else:
5890                segmentation = np.fromfile(segmentation_path, dtype=np.uint8)
5891            return image, segmentation, label
5892
5893        return image, label
5894
5895    def __len__(self):
5896        return len(self.indices)
5897
5898
5899class Flowers102Dataset(GeneratorDataset):
5900    """
5901    A source dataset for reading and parsing Flowers102 dataset.
5902
5903    The generated dataset has two columns :py:obj:`[image, label]` or three :py:obj:`[image, segmentation, label]`.
5904    The tensor of column :py:obj:`image` is of the uint8 type.
5905    The tensor of column :py:obj:`segmentation` is of the uint8 type.
5906    The tensor of column :py:obj:`label` is a scalar or a tensor of the uint32 type.
5907
5908    Args:
5909        dataset_dir (str): Path to the root directory that contains the dataset.
5910        task (str): Specify the 'Classification' or 'Segmentation' task (default='Classification').
5911        usage (str): Specify the 'train', 'valid', 'test' part or 'all' parts of dataset
5912            (default='all', will read all samples).
5913        num_samples (int, optional): The number of samples to be included in the dataset (default=None, all images).
5914        num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
5915        shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
5916            (default=None, expected order behavior shown in the table).
5917        decode (bool, optional): Whether or not to decode the images and segmentations after reading (default=False).
5918        sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. Random accessible
5919            input is required (default=None, expected order behavior shown in the table).
5920        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
5921            Random accessible input is required. When this argument is specified, 'num_samples' reflects the max
5922            sample number of per shard.
5923        shard_id (int, optional): The shard ID within num_shards (default=None). This argument must be specified only
5924            when num_shards is also specified. Random accessible input is required.
5925
5926    Raises:
5927        RuntimeError: If dataset_dir does not contain data files.
5928        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
5929        RuntimeError: If sampler and shuffle are specified at the same time.
5930        RuntimeError: If sampler and sharding are specified at the same time.
5931        RuntimeError: If num_shards is specified but shard_id is None.
5932        RuntimeError: If shard_id is specified but num_shards is None.
5933        ValueError: If shard_id is invalid (< 0 or >= num_shards).
5934
5935    Note:
5936        - This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive.
5937          The table below shows what input arguments are allowed and their expected behavior.
5938
5939    .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
5940       :widths: 25 25 50
5941       :header-rows: 1
5942
5943       * - Parameter 'sampler'
5944         - Parameter 'shuffle'
5945         - Expected Order Behavior
5946       * - None
5947         - None
5948         - random order
5949       * - None
5950         - True
5951         - random order
5952       * - None
5953         - False
5954         - sequential order
5955       * - Sampler object
5956         - None
5957         - order defined by sampler
5958       * - Sampler object
5959         - True
5960         - not allowed
5961       * - Sampler object
5962         - False
5963         - not allowed
5964
5965    Examples:
5966        >>> flowers102_dataset_dir = "/path/to/flowers102_dataset_directory"
5967        >>> dataset = ds.Flowers102Dataset(dataset_dir=flowers102_dataset_dir,
5968        ...                                task="Classification",
5969        ...                                usage="all",
5970        ...                                decode=True)
5971
5972    About Flowers102 dataset:
5973
5974    Flowers102 dataset consists of 102 flower categories.
5975    The flowers commonly occur in the United Kingdom.
5976    Each class consists of between 40 and 258 images.
5977
5978    Here is the original Flowers102 dataset structure.
5979    You can unzip the dataset files into this directory structure and read them with MindSpore's API.
5980
5981    .. code-block::

5982        .
5983        └── flowers102_dataset_dir
5984             ├── imagelabels.mat
5985             ├── setid.mat
5986             ├── jpg
5987                  ├── image_00001.jpg
5988                  ├── image_00002.jpg
5989                  ├── ...
5990             ├── segmim
5991                  ├── segmim_00001.jpg
5992                  ├── segmim_00002.jpg
5993                  ├── ...
5994
5995    Citation:
5996
5997    .. code-block::
5998
5999        @InProceedings{Nilsback08,
6000          author       = "Maria-Elena Nilsback and Andrew Zisserman",
6001          title        = "Automated Flower Classification over a Large Number of Classes",
6002          booktitle    = "Indian Conference on Computer Vision, Graphics and Image Processing",
6003          month        = "Dec",
6004          year         = "2008",
6005        }
6006    """
6007
6008    @check_flowers102dataset
6009    def __init__(self, dataset_dir, task="Classification", usage="all", num_samples=None, num_parallel_workers=1,
6010                 shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
6011        self.dataset_dir = os.path.realpath(dataset_dir)
6012        self.task = replace_none(task, "Classification")
6013        self.usage = replace_none(usage, "all")
6014        self.decode = replace_none(decode, False)
6015        dataset = _Flowers102Dataset(self.dataset_dir, self.task, self.usage, self.decode)
6016        super().__init__(dataset, column_names=dataset.column_names, num_samples=num_samples,
6017                         num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
6018                         num_shards=num_shards, shard_id=shard_id)
6019
6020    def get_class_indexing(self):
6021        """
6022        Get the class index.
6023
6024        Returns:
6025            dict, a str-to-int mapping from label name to index.
6026        """
6027        class_names = [
6028            "pink primrose", "hard-leaved pocket orchid", "canterbury bells",
6029            "sweet pea", "english marigold", "tiger lily", "moon orchid",
6030            "bird of paradise", "monkshood", "globe thistle", "snapdragon",
6031            "colt's foot", "king protea", "spear thistle", "yellow iris",
6032            "globe-flower", "purple coneflower", "peruvian lily", "balloon flower",
6033            "giant white arum lily", "fire lily", "pincushion flower", "fritillary",
6034            "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers",
6035            "stemless gentian", "artichoke", "sweet william", "carnation",
6036            "garden phlox", "love in the mist", "mexican aster", "alpine sea holly",
6037            "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip",
6038            "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia",
6039            "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy",
6040            "common dandelion", "petunia", "wild pansy", "primula", "sunflower",
6041            "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia",
6042            "pink-yellow dahlia?", "cautleya spicata", "japanese anemone",
6043            "black-eyed susan", "silverbush", "californian poppy", "osteospermum",
6044            "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania",
6045            "azalea", "water lily", "rose", "thorn apple", "morning glory",
6046            "passion flower", "lotus", "toad lily", "anthurium", "frangipani",
6047            "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow",
6048            "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum",
6049            "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow",
6050            "mexican petunia", "bromelia", "blanket flower", "trumpet creeper",
6051            "blackberry lily"
6052        ]
6053
6054        class_dict = {}
6055        for i, class_name in enumerate(class_names):
6056            class_dict[class_name] = i
6057
6058        return class_dict
6059
6060
6061class TextFileDataset(SourceDataset):
6062    """
6063    A source dataset that reads and parses datasets stored on disk in text format.
6064    The generated dataset has one column :py:obj:`[text]` with type string.
6065
6066    Args:
6067        dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a
6068            pattern of files. The list will be sorted in a lexicographical order.
6069        num_samples (int, optional): The number of samples to be included in the dataset
6070            (default=None, will include all samples).
6071        num_parallel_workers (int, optional): Number of workers to read the data
6072            (default=None, number set in the config).
6073        shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
6074            (default=Shuffle.GLOBAL).
6075            If shuffle is False, no shuffling will be performed;
6076            If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL;
6077            Otherwise, there are two levels of shuffling:
6078
6079            - Shuffle.GLOBAL: Shuffle both the files and samples.
6080
6081            - Shuffle.FILES: Shuffle files only.
6082
6083        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
6084            When this argument is specified, `num_samples` reflects the maximum sample number of per shard.
6085        shard_id (int, optional): The shard ID within num_shards (default=None). This
6086            argument can only be specified when num_shards is also specified.
6087        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
6088            (default=None, which means no cache is used).
6089
6090    Raises:
6091        RuntimeError: If dataset_files are not valid or do not exist.
6092        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
6093        RuntimeError: If num_shards is specified but shard_id is None.
6094        RuntimeError: If shard_id is specified but num_shards is None.
6095
6096    Examples:
6097        >>> text_file_dataset_dir = ["/path/to/text_file_dataset_file"] # contains 1 or multiple text files
6098        >>> dataset = ds.TextFileDataset(dataset_files=text_file_dataset_dir)
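        >>>
        >>> # A small sketch of the shuffle levels described above (same placeholder file list):
        >>> # shuffle the file order only, keeping the sample order inside each file
        >>> dataset = ds.TextFileDataset(dataset_files=text_file_dataset_dir, shuffle=ds.Shuffle.FILES)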
6099    """
6100
6101    @check_textfiledataset
6102    def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL,
6103                 num_shards=None, shard_id=None, cache=None):
6104        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
6105                         num_shards=num_shards, shard_id=shard_id, cache=cache)
6106        self.dataset_files = self._find_files(dataset_files)
6107        self.dataset_files.sort()
6108
6109    def parse(self, children=None):
6110        return cde.TextFileNode(self.dataset_files, self.num_samples, self.shuffle_flag, self.num_shards,
6111                                self.shard_id)
6112
6113
6114class _NumpySlicesDataset:
6115    """
6116    Mainly for dealing with several formats of Python data, returning one row each time.
6117    """
6118
6119    def __init__(self, data, column_list=None):
6120        self.column_list = None
6121        # Convert dict data into tuple
6122        if isinstance(data, dict):
6123            data = self.process_dict(data)
6124
6125        if isinstance(data, tuple):
6126            self.data = ()
6127            data_len = len(data)
6128            for i in range(data_len):
6129                self.data = self.data + (np.array(data[i]),)
6130        else:
6131            self.data = (np.array(data),)
6132
6133        # check whether the data length in each column is equal
6134        data_len = [len(data_item) for data_item in self.data]
6135        if data_len[1:] != data_len[:-1]:
6136            raise ValueError("Data length in each column is not equal.")
6137
6138        # Init column_name
6139        if column_list is not None:
6140            self.column_list = column_list
6141        elif self.column_list is None:
6142            self.column_list = []
6143            column_num = len(self.data)
6144            for i in range(column_num):
6145                self.column_list.append("column_" + str(i))
6146
6147    def __getitem__(self, index):
6148        data_row = [d[index, ...] for d in self.data]
6149        data_res = tuple(data_row)
6150        return data_res
6151
6152    def __len__(self):
6153        return len(self.data[0])
6154
6155    def process_dict(self, input_data):
6156        """
6157        Convert dict-like data into tuple format; a pandas-like dict (whose values expose a "values" attribute) is converted into a plain dict first.
6158        """
6159        # Convert a pandas-like dict (whose values have a "values" attribute) into a general dict
6160        data_keys = list(input_data.keys())
6161        data_col = input_data[data_keys[0]]
6162        if hasattr(data_col, "values"):
6163            new_dict = {}
6164            for key in data_keys:
6165                item1 = input_data.pop(key)
6166                new_dict[key] = item1.values
6167            input_data = new_dict
6168
6169        # Convert the data in dict into tuple
6170        data = ()
6171        keys = list(input_data.keys())
6172        self.column_list = keys
6173        for key in keys:
6174            value = input_data[key]
6175            data = data + (list(value),)
6176
6177        return data
6178
6179
6180class NumpySlicesDataset(GeneratorDataset):
6181    """
6182    Creates a dataset with given data slices, mainly for loading Python data into dataset.
6183
6184    The column names and column types of generated dataset depend on Python data defined by users.
6185
6186    Args:
6187        data (Union[list, tuple, dict]): Input of given data. Supported data types include: list, tuple, dict and other
6188            NumPy formats. Input data will be sliced along the first dimension to generate additional rows. If the input
6189            is a list, there will be one column in each row; otherwise multiple columns are generated. Loading large data
6190            in this way is not recommended because all of the data is loaded into memory.
6191        column_names (list[str], optional): List of column names of the dataset (default=None). If column_names is not
6192            provided, the output column names will be named as the keys of dict when the input data is a dict,
6193            otherwise they will be named like column_0, column_1 ...
6194        num_samples (int, optional): The number of samples to be included in the dataset (default=None, all samples).
6195        num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
6196        shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
6197            (default=None, expected order behavior shown in the table).
6198        sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. Random accessible
6199            input is required (default=None, expected order behavior shown in the table).
6200        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
6201            Random accessible input is required. When this argument is specified, `num_samples` reflects the max
6202            sample number of per shard.
6203        shard_id (int, optional): The shard ID within num_shards (default=None). This argument must be specified only
6204            when num_shards is also specified. Random accessible input is required.
6205
6206    Note:
6207        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
6208          The table below shows what input arguments are allowed and their expected behavior.
6209
6210    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
6211       :widths: 25 25 50
6212       :header-rows: 1
6213
6214       * - Parameter `sampler`
6215         - Parameter `shuffle`
6216         - Expected Order Behavior
6217       * - None
6218         - None
6219         - random order
6220       * - None
6221         - True
6222         - random order
6223       * - None
6224         - False
6225         - sequential order
6226       * - Sampler object
6227         - None
6228         - order defined by sampler
6229       * - Sampler object
6230         - True
6231         - not allowed
6232       * - Sampler object
6233         - False
6234         - not allowed
6235
6236    Raises:
6237        RuntimeError: If the length of column_names does not match the number of output columns of data.
6238        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
6239        RuntimeError: If sampler and shuffle are specified at the same time.
6240        RuntimeError: If sampler and sharding are specified at the same time.
6241        RuntimeError: If num_shards is specified but shard_id is None.
6242        RuntimeError: If shard_id is specified but num_shards is None.
6243        ValueError: If shard_id is invalid (< 0 or >= num_shards).
6244
6245    Examples:
6246        >>> # 1) Input data can be a list
6247        >>> data = [1, 2, 3]
6248        >>> dataset = ds.NumpySlicesDataset(data=data, column_names=["column_1"])
6249        >>>
6250        >>> # 2) Input data can be a dictionary, and column_names will be its keys
6251        >>> data = {"a": [1, 2], "b": [3, 4]}
6252        >>> dataset = ds.NumpySlicesDataset(data=data)
6253        >>>
6254        >>> # 3) Input data can be a tuple of lists (or NumPy arrays), each tuple element refers to data in each column
6255        >>> data = ([1, 2], [3, 4], [5, 6])
6256        >>> dataset = ds.NumpySlicesDataset(data=data, column_names=["column_1", "column_2", "column_3"])
6257        >>>
6258        >>> # 4) Load data from CSV file
6259        >>> import pandas as pd
6260        >>> df = pd.read_csv(filepath_or_buffer=csv_dataset_dir[0])
6261        >>> dataset = ds.NumpySlicesDataset(data=dict(df), shuffle=False)
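        >>>
        >>> # 5) A sketch of slicing NumPy arrays along the first dimension (one row per sample)
        >>> import numpy as np
        >>> features = np.arange(12).reshape(6, 2)
        >>> labels = np.arange(6)
        >>> dataset = ds.NumpySlicesDataset(data=(features, labels), column_names=["feature", "label"], shuffle=False)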
6262    """
6263
6264    @check_numpyslicesdataset
6265    def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None,
6266                 num_shards=None, shard_id=None):
6267        dataset = _NumpySlicesDataset(data, column_names)
6268        super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples,
6269                         num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
6270                         num_shards=num_shards, shard_id=shard_id)
6271
6272
6273class _PaddedDataset:
6274    """
6275    Mainly for combining filler samples provided by users into a dataset.
6276
6277    Args:
6278        padded_samples (list(dict)): Data provided by user to be added to the initial Dataset.
6279    """
6280
6281    def __init__(self, padded_samples):
6282        self.column_names = list(padded_samples[0].keys())
6283        self.padded_samples = padded_samples
6284
6285    def __getitem__(self, item):
6286        return (self.padded_samples[item][key] for key in self.column_names)
6287
6288    def __len__(self):
6289        return len(self.padded_samples)
6290
6291
6292class PaddedDataset(GeneratorDataset):
6293    """
6294    Creates a dataset with filler data provided by the user. Mainly used to append filler samples to the original
6295    dataset and assign them to the corresponding shard.
6296
6297    Args:
6298        padded_samples (list(dict)): Samples provided by user.
6299
6300    Raises:
6301        TypeError: If padded_samples is not an instance of list.
6302        TypeError: If the element of padded_samples is not an instance of dict.
6303        ValueError: If the padded_samples is empty.
6304
6305    Examples:
6306        >>> import numpy as np
6307        >>> data = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)}]
6308        >>> dataset = ds.PaddedDataset(padded_samples=data)
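        >>>
        >>> # A minimal sketch: iterate the padded samples; they come back in the given order since shuffle is disabled
        >>> for item in dataset.create_dict_iterator(output_numpy=True):
        ...     print(item["image"].shape)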
6309    """
6310
6311    @check_paddeddataset
6312    def __init__(self, padded_samples):
6313        dataset = _PaddedDataset(padded_samples)
6314        super().__init__(dataset, column_names=dataset.column_names, num_shards=None, shard_id=None, shuffle=False)
6315        self._dataset_size = len(dataset.padded_samples)
6316        self.padded_samples = padded_samples
6317
6318
6319class FlickrDataset(MappableDataset):
6320    """
6321    A source dataset for reading and parsing Flickr8k and Flickr30k dataset.
6322
6323    The generated dataset has two columns :py:obj:`[image, annotation]`.
6324    The tensor of column :py:obj:`image` is of the uint8 type.
6325    The tensor of column :py:obj:`annotation` is a tensor which contains 5 annotation strings,
6326    such as ["a", "b", "c", "d", "e"].
6327
6328    Args:
6329        dataset_dir (str): Path to the root directory that contains the dataset.
6330        annotation_file (str): Path to the file that contains the annotations.
6331        num_samples (int, optional): The number of images to be included in the dataset.
6332            (default=None, all images).
6333        num_parallel_workers (int, optional): Number of workers to read the data
6334            (default=None, number set in the config).
6335        shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
6336            order behavior shown in the table).
6337        decode (bool, optional): Decode the images after reading (default=False).
6338        sampler (Sampler, optional): Object used to choose samples from the
6339            dataset (default=None, expected order behavior shown in the table).
6340        num_shards (int, optional): Number of shards that the dataset will be divided
6341            into (default=None). When this argument is specified, `num_samples` reflects
6342            the max sample number of per shard.
6343        shard_id (int, optional): The shard ID within num_shards (default=None). This
6344            argument can only be specified when num_shards is also specified.
6345        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
6346            (default=None, which means no cache is used).
6347
6348    Raises:
6349        RuntimeError: If dataset_dir is not valid or does not contain data files.
6350        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
6351        RuntimeError: If sampler and shuffle are specified at the same time.
6352        RuntimeError: If sampler and sharding are specified at the same time.
6353        RuntimeError: If num_shards is specified but shard_id is None.
6354        RuntimeError: If shard_id is specified but num_shards is None.
6355        ValueError: If dataset_dir does not exist.
6356        ValueError: If annotation_file does not exist.
6357        ValueError: If shard_id is invalid (< 0 or >= num_shards).
6358
6359    Note:
6360        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
6361          The table below shows what input arguments are allowed and their expected behavior.
6362
6363    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
6364       :widths: 25 25 50
6365       :header-rows: 1
6366
6367       * - Parameter `sampler`
6368         - Parameter `shuffle`
6369         - Expected Order Behavior
6370       * - None
6371         - None
6372         - random order
6373       * - None
6374         - True
6375         - random order
6376       * - None
6377         - False
6378         - sequential order
6379       * - Sampler object
6380         - None
6381         - order defined by sampler
6382       * - Sampler object
6383         - True
6384         - not allowed
6385       * - Sampler object
6386         - False
6387         - not allowed
6388
6389    Examples:
6390        >>> flickr_dataset_dir = "/path/to/flickr_dataset_directory"
6391        >>> annotation_file = "/path/to/flickr_annotation_file"
6392        >>>
6393        >>> # 1) Get all samples from FLICKR dataset in sequence
6394        >>> dataset = ds.FlickrDataset(dataset_dir=flickr_dataset_dir,
6395        ...                            annotation_file=annotation_file,
6396        ...                            shuffle=False)
6397        >>>
6398        >>> # 2) Randomly select 350 samples from FLICKR dataset
6399        >>> dataset = ds.FlickrDataset(dataset_dir=flickr_dataset_dir,
6400        ...                            annotation_file=annotation_file,
6401        ...                            num_samples=350,
6402        ...                            shuffle=True)
6403        >>>
6404        >>> # 3) Get samples from FLICKR dataset for shard 0 in a 2-way distributed training
6405        >>> dataset = ds.FlickrDataset(dataset_dir=flickr_dataset_dir,
6406        ...                            annotation_file=annotation_file,
6407        ...                            num_shards=2,
6408        ...                            shard_id=0)
6409        >>>
6410        >>> # In FLICKR dataset, each dictionary has keys "image" and "annotation"
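        >>>
        >>> # A small sketch (paths above are placeholders): each "annotation" entry holds the 5 caption strings of one image
        >>> for item in dataset.create_dict_iterator(output_numpy=True):
        ...     captions = item["annotation"]  # NumPy array holding the 5 caption strings
        ...     break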
6411
6412    About Flickr8k dataset:
6413
6414    The Flickr8k dataset consists of 8092 colour images. There are 40460 annotations in the Flickr8k.token.txt;
6415    each image has 5 annotations.
6416
6417    You can unzip the dataset files into the following directory structure and read them with MindSpore's API.
6418
6419    .. code-block::
6420
6421        .
6422        └── Flickr8k
6423             ├── Flickr8k_Dataset
6424             │    ├── 1000268201_693b08cb0e.jpg
6425             │    ├── 1001773457_577c3a7d70.jpg
6426             │    ├── ...
6427             └── Flickr8k.token.txt
6428
6429    Citation:
6430
6431    .. code-block::
6432
6433        @article{DBLP:journals/jair/HodoshYH13,
6434        author    = {Micah Hodosh and Peter Young and Julia Hockenmaier},
6435        title     = {Framing Image Description as a Ranking Task: Data, Models and Evaluation Metrics},
6436        journal   = {J. Artif. Intell. Res.},
6437        volume    = {47},
6438        pages     = {853--899},
6439        year      = {2013},
6440        url       = {https://doi.org/10.1613/jair.3994},
6441        doi       = {10.1613/jair.3994},
6442        timestamp = {Mon, 21 Jan 2019 15:01:17 +0100},
6443        biburl    = {https://dblp.org/rec/journals/jair/HodoshYH13.bib},
6444        bibsource = {dblp computer science bibliography, https://dblp.org}
6445        }
6446
6447    About Flickr30k dataset:
6448
6449    The Flickr30k dataset consists of 31783 colour images. There are 158915 annotations in
6450    the results_20130124.token; each image has 5 annotations.
6451
6452    You can unzip the dataset files into the following directory structure and read them with MindSpore's API.
6453
6454    .. code-block::
6455
6456        .
6457        └── Flickr30k
6458             ├── flickr30k-images
6459             │    ├── 1000092795.jpg
6460             │    ├── 10002456.jpg
6461             │    ├── ...
6462             └── results_20130124.token
6463
6464    Citation:
6465
6466    .. code-block::
6467
6468        @article{DBLP:journals/tacl/YoungLHH14,
6469        author    = {Peter Young and Alice Lai and Micah Hodosh and Julia Hockenmaier},
6470        title     = {From image descriptions to visual denotations: New similarity metrics
6471                     for semantic inference over event descriptions},
6472        journal   = {Trans. Assoc. Comput. Linguistics},
6473        volume    = {2},
6474        pages     = {67--78},
6475        year      = {2014},
6476        url       = {https://tacl2013.cs.columbia.edu/ojs/index.php/tacl/article/view/229},
6477        timestamp = {Wed, 17 Feb 2021 21:55:25 +0100},
6478        biburl    = {https://dblp.org/rec/journals/tacl/YoungLHH14.bib},
6479        bibsource = {dblp computer science bibliography, https://dblp.org}
6480        }
6481    """
6482
6483    @check_flickr_dataset
6484    def __init__(self, dataset_dir, annotation_file, num_samples=None, num_parallel_workers=None, shuffle=None,
6485                 decode=None, sampler=None, num_shards=None, shard_id=None, cache=None):
6486        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
6487                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
6488
6489        self.dataset_dir = dataset_dir
6490        self.annotation_file = annotation_file
6491        self.decode = replace_none(decode, False)
6492
6493    def parse(self, children=None):
6494        return cde.FlickrNode(self.dataset_dir, self.annotation_file, self.decode, self.sampler)
6495
6496
6497class SBDataset(GeneratorDataset):
6498    """
6499    A source dataset for reading and parsing Semantic Boundaries Dataset.
6500
6501    The generated dataset has two columns: :py:obj:`[image, task]`.
6502    The tensor of column :py:obj:`image` is of the uint8 type.
6503    The tensor of column :py:obj:`task` contains 20 images of the uint8 type if `task` is `Boundaries`;
6504    otherwise it contains 1 image of the uint8 type.
6505
6506    Args:
6507        dataset_dir (str): Path to the root directory that contains the dataset.
6508        task (str, optional): Acceptable tasks include `Boundaries` or `Segmentation` (default=`Boundaries`).
6509        usage (str, optional): Acceptable usages include `train`, `val`, `train_noval` and `all` (default=`all`).
6510        num_samples (int, optional): The number of images to be included in the dataset.
6511            (default=None, all images).
6512        num_parallel_workers (int, optional): Number of workers to read the data
6513            (default=None, number set in the config).
6514        shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
6515            order behavior shown in the table).
        decode (bool, optional): Decode the images after reading (default=None).
6516        sampler (Sampler, optional): Object used to choose samples from the
6517            dataset (default=None, expected order behavior shown in the table).
6518        num_shards (int, optional): Number of shards that the dataset will be divided
6519            into (default=None). When this argument is specified, `num_samples` reflects
6520            the max sample number of per shard.
6521        shard_id (int, optional): The shard ID within num_shards (default=None). This
6522            argument can only be specified when num_shards is also specified.
6523
6524    Raises:
6525        RuntimeError: If dataset_dir is not valid or does not contain data files.
6526        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
6527        RuntimeError: If sampler and shuffle are specified at the same time.
6528        RuntimeError: If sampler and sharding are specified at the same time.
6529        RuntimeError: If num_shards is specified but shard_id is None.
6530        RuntimeError: If shard_id is specified but num_shards is None.
6531        ValueError: If dataset_dir does not exist.
6532        ValueError: If task is not in [`Boundaries`, `Segmentation`].
6533        ValueError: If usage is not in [`train`, `val`, `train_noval`, `all`].
6534        ValueError: If shard_id is invalid (< 0 or >= num_shards).
6535
6536    Note:
6537        - This dataset can take in a sampler. `sampler` and `shuffle` are mutually exclusive.
6538          The table below shows what input arguments are allowed and their expected behavior.
6539
6540    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
6541       :widths: 25 25 50
6542       :header-rows: 1
6543
6544       * - Parameter `sampler`
6545         - Parameter `shuffle`
6546         - Expected Order Behavior
6547       * - None
6548         - None
6549         - random order
6550       * - None
6551         - True
6552         - random order
6553       * - None
6554         - False
6555         - sequential order
6556       * - Sampler object
6557         - None
6558         - order defined by sampler
6559       * - Sampler object
6560         - True
6561         - not allowed
6562       * - Sampler object
6563         - False
6564         - not allowed
6565
6566    Examples:
6567        >>> sb_dataset_dir = "/path/to/sb_dataset_directory"
6568        >>>
6569        >>> # 1) Get all samples from Semantic Boundaries Dataset in sequence
6570        >>> dataset = ds.SBDataset(dataset_dir=sb_dataset_dir, shuffle=False)
6571        >>>
6572        >>> # 2) Randomly select 350 samples from Semantic Boundaries Dataset
6573        >>> dataset = ds.SBDataset(dataset_dir=sb_dataset_dir, num_samples=350, shuffle=True)
6574        >>>
6575        >>> # 3) Get samples from Semantic Boundaries Dataset for shard 0 in a 2-way distributed training
6576        >>> dataset = ds.SBDataset(dataset_dir=sb_dataset_dir, num_shards=2, shard_id=0)
6577        >>>
6578        >>> # In Semantic Boundaries Dataset, each dictionary has keys "image" and "task"
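        >>>
        >>> # 4) A sketch of the task switch described above: read category-level segmentation masks instead of boundaries
        >>> dataset = ds.SBDataset(dataset_dir=sb_dataset_dir, task="Segmentation", usage="val", shuffle=False)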
6579
6580    About Semantic Boundaries Dataset:
6581
6582    The Semantic Boundaries Dataset consists of 11355 colour images. There are 8498 image names in train.txt,
6583    2857 image names in val.txt and 5623 image names in train_noval.txt. The directory cls/ contains the
6584    category-level Segmentation and Boundaries results, and the directory inst/ contains the instance-level
6585    Segmentation and Boundaries results.
6586
6587    You can unzip the dataset files into the following structure and read them with MindSpore's API:
6588
6589    .. code-block::
6590
6591         .
6592         └── benchmark_RELEASE
6593              ├── dataset
6594              ├── img
6595              │    ├── 2008_000002.jpg
6596              │    ├── 2008_000003.jpg
6597              │    ├── ...
6598              ├── cls
6599              │    ├── 2008_000002.mat
6600              │    ├── 2008_000003.mat
6601              │    ├── ...
6602              ├── inst
6603              │    ├── 2008_000002.mat
6604              │    ├── 2008_000003.mat
6605              │    ├── ...
6606              ├── train.txt
6607              └── val.txt
6608
    Citation:

6609    .. code-block::
6610
6611        @InProceedings{BharathICCV2011,
6612            author       = "Bharath Hariharan and Pablo Arbelaez and Lubomir Bourdev and
6613                            Subhransu Maji and Jitendra Malik",
6614            title        = "Semantic Contours from Inverse Detectors",
6615            booktitle    = "International Conference on Computer Vision (ICCV)",
6616            year         = "2011",
        }
6617    """
6618
6619    @check_sb_dataset
6620    def __init__(self, dataset_dir, task='Boundaries', usage='all', num_samples=None, num_parallel_workers=1,
6621                 shuffle=None, decode=None, sampler=None, num_shards=None, shard_id=None):
6622        dataset = _SBDataset(dataset_dir, task, usage, decode)
6623        super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples,
6624                         num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
6625                         num_shards=num_shards, shard_id=shard_id)
6626
6627
6628class _SBDataset:
6629    """
6630    Deals with data files with the .mat extension, and returns one row as a tuple (image, task) each time.
6631    """
6632
6633    def __init__(self, dataset_dir, task, usage, decode):
6634        self.column_list = ['image', 'task']
6635        self.task = task
6636        self.images_path = os.path.join(dataset_dir, 'img')
6637        self.cls_path = os.path.join(dataset_dir, 'cls')
6638        self._loadmat = loadmat
6639        self.categories = 20
6640        self.decode = replace_none(decode, False)
6641
6642        if usage == "all":
6643            image_names = []
6644            for item in ["train", "val"]:
6645                usage_path = os.path.join(dataset_dir, item + '.txt')
6646                if not os.path.exists(usage_path):
6647                    raise FileNotFoundError("SBDataset: {0} not found".format(usage_path))
6648                with open(usage_path, 'r') as f:
6649                    image_names += [x.strip() for x in f.readlines()]
6650        else:
6651            usage_path = os.path.join(dataset_dir, usage + '.txt')
6652            if not os.path.exists(usage_path):
6653                raise FileNotFoundError("SBDataset: {0} not found".format(usage_path))
6654            with open(usage_path, 'r') as f:
6655                image_names = [x.strip() for x in f.readlines()]
6656
6657        self.images = [os.path.join(self.images_path, i + ".jpg") for i in image_names]
6658        self.clss = [os.path.join(self.cls_path, i + ".mat") for i in image_names]
6659
6660        if len(self.images) != len(self.clss):
6661            raise ValueError("SBDataset: images count not equal to cls count")
6662
6663        self._get_data = self._get_boundaries_data if self.task == "Boundaries" else self._get_segmentation_data
6664        self._get_item = self._get_decode_item if self.decode else self._get_undecode_item
6665
6666    def _get_boundaries_data(self, mat_path):
6667        mat_data = self._loadmat(mat_path)
6668        return np.concatenate([np.expand_dims(mat_data['GTcls'][0][self.task][0][i][0].toarray(), axis=0)
6669                               for i in range(self.categories)], axis=0)
6670
6671    def _get_segmentation_data(self, mat_path):
6672        mat_data = self._loadmat(mat_path)
6673        return Image.fromarray(mat_data['GTcls'][0][self.task][0])
6674
6675    def _get_decode_item(self, idx):
6676        return Image.open(self.images[idx]).convert('RGB'), self._get_data(self.clss[idx])
6677
6678    def _get_undecode_item(self, idx):
6679        return np.fromfile(self.images[idx], dtype=np.uint8), self._get_data(self.clss[idx])
6680
6681    def __len__(self):
6682        return len(self.images)
6683
6684    def __getitem__(self, idx):
6685        return self._get_item(idx)
6686
6687
6688class DeserializedDataset(Dataset):
6689    def __init__(self, input_obj):
6690        super().__init__()
6691        self.input_obj = input_obj
6692
6693    def parse(self, children=None):
6694        if isinstance(self.input_obj, dict):
6695            json_str = json.dumps(self.input_obj)
6696            return cde.Dataset.from_json_string(json_str)
6697        return cde.Dataset.from_json_file(self.input_obj)
6698
6699
6700class CityscapesDataset(MappableDataset):
6701    """
6702    A source dataset for reading and parsing Cityscapes dataset.
6703
6704    The generated dataset has two columns :py:obj:`[image, task]`.
6705    The tensor of column :py:obj:`image` is of the uint8 type.
6706    The tensor of column :py:obj:`task` is of the uint8 type if task is not 'polygon'; otherwise it is
6707    a string tensor containing serialized JSON.
6708
6709    Args:
6710        dataset_dir (str): Path to the root directory that contains the dataset.
6711        usage (str): Acceptable usages include `train`, `test`, `val` or `all` if quality_mode is `fine`
6712            otherwise `train`, `train_extra`, `val` or `all` (default=`train`).
6713        quality_mode (str): Acceptable quality_modes include `fine` or `coarse` (default=`fine`).
6714        task (str): Acceptable tasks include `instance`, `semantic`, `polygon` or `color` (default=`instance`).
6715        num_samples (int, optional): The number of images to be included in the dataset.
6716            (default=None, all images).
6717        num_parallel_workers (int, optional): Number of workers to read the data
6718            (default=None, number set in the config).
6719        shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
6720            order behavior shown in the table).
6721        decode (bool, optional): Decode the images after reading (default=False).
6722        sampler (Sampler, optional): Object used to choose samples from the
6723            dataset (default=None, expected order behavior shown in the table).
6724        num_shards (int, optional): Number of shards that the dataset will be divided
6725            into (default=None). When this argument is specified, `num_samples` reflects
6726            the max sample number of per shard.
6727        shard_id (int, optional): The shard ID within num_shards (default=None). This
6728            argument can only be specified when num_shards is also specified.
6729        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
6730            (default=None, which means no cache is used).
6731
6732    Raises:
6733        RuntimeError: If dataset_dir is invalid or does not contain data files.
6734        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
6735        RuntimeError: If sampler and shuffle are specified at the same time.
6736        RuntimeError: If sampler and sharding are specified at the same time.
6737        RuntimeError: If num_shards is specified but shard_id is None.
6738        RuntimeError: If shard_id is specified but num_shards is None.
6739        ValueError: If dataset_dir does not exist.
6740        ValueError: If task is invalid.
6741        ValueError: If quality_mode is invalid.
6742        ValueError: If usage is invalid.
6743        ValueError: If shard_id is invalid (< 0 or >= num_shards).
6744
6745    Note:
6746        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
6747          The table below shows what input arguments are allowed and their expected behavior.
6748
6749    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
6750       :widths: 25 25 50
6751       :header-rows: 1
6752
6753       * - Parameter `sampler`
6754         - Parameter `shuffle`
6755         - Expected Order Behavior
6756       * - None
6757         - None
6758         - random order
6759       * - None
6760         - True
6761         - random order
6762       * - None
6763         - False
6764         - sequential order
6765       * - Sampler object
6766         - None
6767         - order defined by sampler
6768       * - Sampler object
6769         - True
6770         - not allowed
6771       * - Sampler object
6772         - False
6773         - not allowed
6774
6775    Examples:
6776        >>> cityscapes_dataset_dir = "/path/to/cityscapes_dataset_directory"
6777        >>>
6778        >>> # 1) Get all samples from Cityscapes dataset in sequence
6779        >>> dataset = ds.CityscapesDataset(dataset_dir=cityscapes_dataset_dir, task="instance", quality_mode="fine",
6780        ...                                usage="train", shuffle=False, num_parallel_workers=1)
6781        >>>
6782        >>> # 2) Randomly select 350 samples from Cityscapes dataset
6783        >>> dataset = ds.CityscapesDataset(dataset_dir=cityscapes_dataset_dir, num_samples=350, shuffle=True,
6784        ...                                num_parallel_workers=1)
6785        >>>
6786        >>> # 3) Get samples from Cityscapes dataset for shard 0 in a 2-way distributed training
6787        >>> dataset = ds.CityscapesDataset(dataset_dir=cityscapes_dataset_dir, num_shards=2, shard_id=0,
6788        ...                                num_parallel_workers=1)
6789        >>>
6790        >>> # In Cityscapes dataset, each dictionary has keys "image" and "task"
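        >>>
        >>> # 4) A sketch of the "polygon" task noted above: the task column is returned as a serialized JSON string
        >>> dataset = ds.CityscapesDataset(dataset_dir=cityscapes_dataset_dir, task="polygon", quality_mode="fine",
        ...                                usage="train", num_parallel_workers=1)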
6791
6792    About Cityscapes dataset:
6793
6794    The Cityscapes dataset consists of 5000 colour images with high quality dense pixel annotations and
6795    19998 colour images with coarser polygonal annotations in 50 cities. There are 30 classes in this
6796    dataset and the polygonal annotations include dense semantic segmentation and instance segmentation
6797    for vehicles and people.
6798
6799    You can unzip the dataset files into the following directory structure and read them with MindSpore's API.
6800
6801    Taking the quality_mode of `fine` as an example.
6802
6803    .. code-block::
6804
6805        .
6806        └── Cityscapes
6807             ├── leftImg8bit
6808             |    ├── train
6809             |    |    ├── aachen
6810             |    |    |    ├── aachen_000000_000019_leftImg8bit.png
6811             |    |    |    ├── aachen_000001_000019_leftImg8bit.png
6812             |    |    |    ├── ...
6813             |    |    ├── bochum
6814             |    |    |    ├── ...
6815             |    |    ├── ...
6816             |    ├── test
6817             |    |    ├── ...
6818             |    ├── val
6819             |    |    ├── ...
6820             └── gtFine
6821                  ├── train
6822                  |    ├── aachen
6823                  |    |    ├── aachen_000000_000019_gtFine_color.png
6824                  |    |    ├── aachen_000000_000019_gtFine_instanceIds.png
6825                  |    |    ├── aachen_000000_000019_gtFine_labelIds.png
6826                  |    |    ├── aachen_000000_000019_gtFine_polygons.json
6827                  |    |    ├── aachen_000001_000019_gtFine_color.png
6828                  |    |    ├── aachen_000001_000019_gtFine_instanceIds.png
6829                  |    |    ├── aachen_000001_000019_gtFine_labelIds.png
6830                  |    |    ├── aachen_000001_000019_gtFine_polygons.json
6831                  |    |    ├── ...
6832                  |    ├── bochum
6833                  |    |    ├── ...
6834                  |    ├── ...
6835                  ├── test
6836                  |    ├── ...
6837                  └── val
6838                       ├── ...
6839
6840    Citation:
6841
6842    .. code-block::
6843
6844        @inproceedings{Cordts2016Cityscapes,
6845        title       = {The Cityscapes Dataset for Semantic Urban Scene Understanding},
6846        author      = {Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler,
6847                        Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt},
6848        booktitle   = {Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
6849        year        = {2016}
6850        }
6851    """
6852
6853    @check_cityscapes_dataset
6854    def __init__(self, dataset_dir, usage="train", quality_mode="fine", task="instance", num_samples=None,
6855                 num_parallel_workers=None, shuffle=None, decode=None, sampler=None, num_shards=None,
6856                 shard_id=None, cache=None):
6857        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
6858                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
6859
6860        self.dataset_dir = dataset_dir
6861        self.task = task
6862        self.quality_mode = quality_mode
6863        self.usage = usage
6864        self.decode = replace_none(decode, False)
6865
6866    def parse(self, children=None):
6867        return cde.CityscapesNode(self.dataset_dir, self.usage, self.quality_mode, self.task, self.decode, self.sampler)
6868
6869
6870class DIV2KDataset(MappableDataset):
6871    """
6872    A source dataset for reading and parsing the DIV2K dataset.
6873
6874    The generated dataset has two columns :py:obj:`[hr_image, lr_image]`.
6875    The tensor of column :py:obj:`hr_image` is of the uint8 type.
6876    The tensor of column :py:obj:`lr_image` is of the uint8 type.
6877
6878    Args:
6879        dataset_dir (str): Path to the root directory that contains the dataset.
6880        usage (str): Acceptable usages include `train`, `valid` or `all` (default=`train`).
6881        downgrade (str): Acceptable downgrades include `bicubic`, `unknown`, `mild`, `difficult` or
6882            `wild` (default=`bicubic`).
6883        scale (int): Acceptable scales include 2, 3, 4 or 8 (default=2).
6884            When `downgrade` is `bicubic`, scale can be 2, 3, 4, 8.
6885            When `downgrade` is `unknown`, scale can only be 2, 3, 4.
6886            When `downgrade` is `mild`, `difficult` or `wild`, scale can only be 4.
6887        num_samples (int, optional): The number of images to be included in the dataset.
6888            (default=None, all images).
6889        num_parallel_workers (int, optional): Number of workers to read the data
6890            (default=None, number set in the config).
6891        shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
6892            order behavior shown in the table).
6893        decode (bool, optional): Decode the images after reading (default=False).
6894        sampler (Sampler, optional): Object used to choose samples from the
6895            dataset (default=None, expected order behavior shown in the table).
6896        num_shards (int, optional): Number of shards that the dataset will be divided
6897            into (default=None). When this argument is specified, `num_samples` reflects
6898            the max sample number of per shard.
6899        shard_id (int, optional): The shard ID within num_shards (default=None). This
6900            argument can only be specified when num_shards is also specified.
6901        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
6902            (default=None, which means no cache is used).
6903
6904    Raises:
6905        RuntimeError: If dataset_dir is invalid or does not contain data files.
6906        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
6907        RuntimeError: If sampler and shuffle are specified at the same time.
6908        RuntimeError: If sampler and sharding are specified at the same time.
6909        RuntimeError: If num_shards is specified but shard_id is None.
6910        RuntimeError: If shard_id is specified but num_shards is None.
6911        ValueError: If dataset_dir does not exist.
6912        ValueError: If usage is invalid.
6913        ValueError: If downgrade is invalid.
6914        ValueError: If scale is invalid.
6915        ValueError: If scale is 8 and downgrade is not `bicubic`.
6916        ValueError: If downgrade is in [`mild`, `difficult`, `wild`] and scale is not 4.
6917        ValueError: If shard_id is invalid (< 0 or >= num_shards).
6918
6919    Note:
6920        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
6921          The table below shows what input arguments are allowed and their expected behavior.
6922
6923    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
6924       :widths: 25 25 50
6925       :header-rows: 1
6926
6927       * - Parameter `sampler`
6928         - Parameter `shuffle`
6929         - Expected Order Behavior
6930       * - None
6931         - None
6932         - random order
6933       * - None
6934         - True
6935         - random order
6936       * - None
6937         - False
6938         - sequential order
6939       * - Sampler object
6940         - None
6941         - order defined by sampler
6942       * - Sampler object
6943         - True
6944         - not allowed
6945       * - Sampler object
6946         - False
6947         - not allowed
6948
6949    Examples:
6950        >>> div2k_dataset_dir = "/path/to/div2k_dataset_directory"
6951        >>>
6952        >>> # 1) Get all samples from DIV2K dataset in sequence
6953        >>> dataset = ds.DIV2KDataset(dataset_dir=div2k_dataset_dir, usage="train", scale=2, downgrade="bicubic",
6954        ...                           shuffle=False)
6955        >>>
6956        >>> # 2) Randomly select 350 samples from DIV2K dataset
6957        >>> dataset = ds.DIV2KDataset(dataset_dir=div2k_dataset_dir, usage="train", scale=2, downgrade="bicubic",
6958        ...                           num_samples=350, shuffle=True)
6959        >>>
6960        >>> # 3) Get samples from DIV2K dataset for shard 0 in a 2-way distributed training
6961        >>> dataset = ds.DIV2KDataset(dataset_dir=div2k_dataset_dir, usage="train", scale=2, downgrade="bicubic",
6962        ...                           num_shards=2, shard_id=0)
6963        >>>
6964        >>> # In DIV2K dataset, each dictionary has keys "hr_image" and "lr_image"
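        >>>
        >>> # 4) A sketch of the scale/downgrade constraint described above: "wild" images are only provided at scale 4
        >>> dataset = ds.DIV2KDataset(dataset_dir=div2k_dataset_dir, usage="train", scale=4, downgrade="wild")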
6965
6966    About DIV2K dataset:
6967
6968    The DIV2K dataset consists of 1000 2K resolution images, among which 800 images are for training, 100 images
6969    are for validation and 100 images are for testing. NTIRE 2017 and NTIRE 2018 include only training dataset
6970    and validation dataset.
6971
6972    You can unzip the dataset files into the following directory structure and read them with MindSpore's API.
6973
6974    Take the training set as an example.
6975
6976    .. code-block::
6977
6978        .
6979        └── DIV2K
6980             ├── DIV2K_train_HR
6981             |    ├── 0001.png
6982             |    ├── 0002.png
6983             |    ├── ...
6984             ├── DIV2K_train_LR_bicubic
6985             |    ├── X2
6986             |    |    ├── 0001x2.png
6987             |    |    ├── 0002x2.png
6988             |    |    ├── ...
6989             |    ├── X3
6990             |    |    ├── 0001x3.png
6991             |    |    ├── 0002x3.png
6992             |    |    ├── ...
6993             |    └── X4
6994             |         ├── 0001x4.png
6995             |         ├── 0002x4.png
6996             |         ├── ...
6997             ├── DIV2K_train_LR_unknown
6998             |    ├── X2
6999             |    |    ├── 0001x2.png
7000             |    |    ├── 0002x2.png
7001             |    |    ├── ...
7002             |    ├── X3
7003             |    |    ├── 0001x3.png
7004             |    |    ├── 0002x3.png
7005             |    |    ├── ...
7006             |    └── X4
7007             |         ├── 0001x4.png
7008             |         ├── 0002x4.png
7009             |         ├── ...
7010             ├── DIV2K_train_LR_mild
7011             |    ├── 0001x4m.png
7012             |    ├── 0002x4m.png
7013             |    ├── ...
7014             ├── DIV2K_train_LR_difficult
7015             |    ├── 0001x4d.png
7016             |    ├── 0002x4d.png
7017             |    ├── ...
7018             ├── DIV2K_train_LR_wild
7019             |    ├── 0001x4w.png
7020             |    ├── 0002x4w.png
7021             |    ├── ...
7022             └── DIV2K_train_LR_x8
7023                  ├── 0001x8.png
7024                  ├── 0002x8.png
7025                  ├── ...

7026    Citation:
7027
7028    .. code-block::
7029
7030        @InProceedings{Agustsson_2017_CVPR_Workshops,
7031        author    = {Agustsson, Eirikur and Timofte, Radu},
7032        title     = {NTIRE 2017 Challenge on Single Image Super-Resolution: Dataset and Study},
7033        booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
7034        url       = "http://www.vision.ee.ethz.ch/~timofter/publications/Agustsson-CVPRW-2017.pdf",
7035        month     = {July},
7036        year      = {2017}
7037        }
7038    """
7039
7040    @check_div2k_dataset
7041    def __init__(self, dataset_dir, usage="train", downgrade="bicubic", scale=2, num_samples=None,
7042                 num_parallel_workers=None, shuffle=None, decode=None, sampler=None, num_shards=None,
7043                 shard_id=None, cache=None):
7044        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
7045                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
7046
7047        self.dataset_dir = dataset_dir
7048        self.usage = usage
7049        self.scale = scale
7050        self.downgrade = downgrade
7051        self.decode = replace_none(decode, False)
7052
7053    def parse(self, children=None):
7054        return cde.DIV2KNode(self.dataset_dir, self.usage, self.downgrade, self.scale, self.decode, self.sampler)
7055