1# Copyright 2022-2023 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15""" 161. This file is an abstraction of the dataset loading class. It contains 17some basic dataset operations(skip, filter, map, batch, ...). 182. Specific dataset loading classes can be found in datasets_vision.py, datasets_text.py, 19datasets_audio.py, datasets_standard_format.py and dataets_user_defined.py files. 20 datasets_vision.py: contains vision dataset loading classes. 21 datasets_text.py: contains text dataset loading classes. 22 datasets_audio.py: contains audio dataset loading classes. 23 datasets_standard_format.py: contains standard format loading classes which 24 any other kinds of datasets can be converted to. 25 dataets_user_defined.py: contains basic classes that help users to define 26 flexible ways to load dataset. 27""" 28import atexit 29import glob 30import json 31import os 32import queue 33import signal 34import stat 35import subprocess 36import warnings 37 38import gc 39import time 40import uuid 41import multiprocessing 42from enum import Enum 43from importlib import import_module 44import sys 45import threading 46 47import copy 48import weakref 49import platform 50import psutil 51 52import mindspore._c_dataengine as cde 53from mindspore._c_expression import typing 54 55from mindspore import log as logger 56from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _get_ps_context,\ 57 _enable_distributed_mindrt 58from mindspore.dataset.engine.offload import GetOffloadModel 59 60import mindspore.dataset.transforms.c_transforms as c_transforms 61import mindspore.dataset.transforms.py_transforms as py_transforms 62import mindspore.dataset.transforms as transforms 63from mindspore.dataset.text.utils import SentencePieceModel, DE_C_INTER_SENTENCEPIECE_MODE 64from mindspore.parallel._utils import _get_device_num 65from mindspore.dataset.debug import DebugHook 66 67from mindspore.dataset.engine import samplers 68from .iterators import DictIterator, TupleIterator, DummyIterator, check_iterator_cleanup, _set_iterator_cleanup, \ 69 ITERATORS_LIST, _unset_iterator_cleanup 70from .queue import _SharedQueue, _Queue 71from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ 72 check_rename, check_device_send, check_take, check_output_shape, check_project, \ 73 check_sync_wait, check_zip_dataset, check_add_column, check_concat, check_split, check_bucket_batch_by_length, \ 74 check_save, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_padded_batch, \ 75 check_total_batch 76from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \ 77 get_enable_watchdog, get_seed, set_seed, get_debug_mode, get_multiprocessing_timeout_interval, _get_debug_hook_list 78from ..core.datatypes import mstype_to_detype 79from ..core.validator_helpers import replace_none 80from ..core.py_util_helpers import ExceptionHandler 81from ..transforms.py_transforms_util import FuncWrapper, Implementation 82from ..vision.transforms import ToNumpy 83from ...mindrecord.config import _get_enc_key, _get_enc_mode, _get_hash_mode, encrypt, append_hash_to_file 84 85try: 86 context = import_module("mindspore.context") 87except ModuleNotFoundError: 88 context = None 89 90if platform.system().lower() == "darwin" and multiprocessing.get_start_method() != "fork": 91 multiprocessing.set_start_method("fork", True) 92 93OffloadToManualOffloadMode = { 94 None: cde.ManualOffloadMode.UNSPECIFIED, 95 False: cde.ManualOffloadMode.DISABLED, 96 True: cde.ManualOffloadMode.ENABLED 97} 98 99_train_dataset = None 100 101 102def _set_training_dataset(dataset): 103 """ 104 Set the dataset to be used when training recovery has occurred. 105 106 Args: 107 dataset: the training dataset or iterator 108 """ 109 global _train_dataset 110 _train_dataset = dataset 111 112 113def _get_training_dataset(): 114 """ 115 Get the dataset to be used when training recovery has occurred. 116 117 Returns: 118 training dataset/iterator 119 """ 120 return _train_dataset 121 122 123def _reset_training_dataset(global_step, dataset_size): 124 """ 125 Reset the training dataset to the given global step. 126 127 Args: 128 global_step (int): Number of global steps that have completed training. 129 Dataset will provide data from its next step after reset. 130 dataset_size (int): Number of steps per epoch. 131 """ 132 dataset = _get_training_dataset() 133 if dataset is not None: 134 dataset._reset(global_step, dataset_size) # pylint: disable=protected-access 135 else: 136 raise RuntimeError("Training dataset is not set.") 137 138 139class Shuffle(str, Enum): 140 """Specify the shuffle mode. 141 142 - ``Shuffle.GLOBAL`` : Shuffle both the files and samples. 143 - ``Shuffle.FILES`` : Shuffle files only. 144 - ``Shuffle.INFILE`` : Shuffle data within each file. 145 """ 146 GLOBAL: str = "global" 147 FILES: str = "files" 148 INFILE: str = "infile" 149 150 151ShuffleToShuffleMode = {Shuffle.FILES: cde.ShuffleMode.FILES, 152 Shuffle.GLOBAL: cde.ShuffleMode.GLOBAL, 153 Shuffle.INFILE: cde.ShuffleMode.INFILE} 154 155 156def shuffle_to_shuffle_mode(shuffle): 157 """ 158 Shuffle Enum to Shuffle Mode 159 160 Args: 161 shuffle (Shuffle): shuffle flag to shuffle mode in C layer 162 163 Returns: 164 ShuffleMode, shuffle mode 165 """ 166 shuffle_mode = cde.ShuffleMode.GLOBAL # Global shuffle 167 if not isinstance(shuffle, Shuffle): 168 if shuffle is None or shuffle: 169 shuffle_mode = cde.ShuffleMode.GLOBAL # Global shuffle 170 else: 171 shuffle_mode = cde.ShuffleMode.FALSE # No shuffle 172 else: 173 shuffle_mode = ShuffleToShuffleMode[shuffle] 174 return shuffle_mode 175 176 177def shuffle_to_bool(shuffle): 178 """ 179 Shuffle Enum to bool 180 181 Args: 182 shuffle (Shuffle): shuffle flag to bool 183 184 Returns: 185 bool, True / False 186 """ 187 if shuffle is not None and not isinstance(shuffle, (bool, Shuffle)): 188 raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like 'Shuffle.GLOBAL' or " 189 "'Shuffle.FILES' or 'Shuffle.INFILE'.") 190 191 shuffle_bool = True 192 if not isinstance(shuffle, Shuffle): 193 if shuffle is None: 194 shuffle_bool = None 195 elif shuffle: 196 shuffle_bool = True 197 else: 198 shuffle_bool = False 199 else: 200 shuffle_bool = True 201 return shuffle_bool 202 203 204@check_zip 205def zip(datasets): 206 """ 207 Zip the datasets in the input tuple of datasets. 208 209 Args: 210 datasets (tuple[Dataset]): A tuple of datasets to be zipped together. 211 The number of datasets must be more than 1. 212 213 Returns: 214 Dataset, a new dataset with the above operation applied. 215 216 Raises: 217 ValueError: If the number of datasets is 1. 218 TypeError: If datasets is not a tuple. 219 220 Examples: 221 >>> # Create a dataset which is the combination of dataset_1 and dataset_2 222 >>> import mindspore.dataset as ds 223 >>> 224 >>> dataset_1 = ds.GeneratorDataset([1], "column1") 225 >>> dataset_2 = ds.GeneratorDataset([2], "column2") 226 >>> dataset = ds.zip((dataset_1, dataset_2)) 227 """ 228 if len(datasets) <= 1: 229 raise ValueError( 230 "Can't zip empty or just one dataset!") 231 for dataset in datasets: 232 if not isinstance(dataset, Dataset): 233 raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset)) 234 return ZipDataset(datasets) 235 236 237def _get_operator_process(): 238 """ 239 Inner implemented method, mainly for passing sub-process id in C layer 240 241 Returns: 242 dict, mapping dict of operation id and corresponding process id. 243 """ 244 global _OP_PROCESS 245 process_info = _OP_PROCESS 246 op_process = dict() 247 keys = process_info.keys() 248 fetched_all = True 249 for key in keys: 250 try: 251 op_process[key] = list(process_info[key][1]) 252 item_full = (len(process_info[key][1]) == process_info[key][0]) 253 except KeyError as err: 254 raise err 255 fetched_all = fetched_all and item_full 256 return op_process, fetched_all 257 258 259def _set_dataset_permissions(file_name, num_files): 260 """ 261 set saved dataset files' permissions to 600 262 the rule of dataset filenames should be the same as those in C++. 263 """ 264 num_digits = len(str(num_files - 1)) 265 if num_files == 1: 266 paths = [file_name] 267 else: 268 paths = ["{}{}".format(file_name, str(x).rjust(num_digits, '0')) for x in range(num_files)] 269 270 for item in paths: 271 if os.path.exists(item): 272 os.chmod(item, stat.S_IRUSR | stat.S_IWUSR) 273 index_file = item + ".db" 274 if os.path.exists(index_file): 275 os.chmod(index_file, stat.S_IRUSR | stat.S_IWUSR) 276 277 278class Dataset: 279 """ 280 Abstract class to represent a dataset in DataEngine's data pipeline. 281 282 This class is the base class of SourceDataset and Dataset, and represents 283 a node in the data flow graph. 284 Dataset 285 ----------------------------------------------------------- 286 | | | | 287 VisionBaseDataset TextBaseDataset AudioBaseDataset | 288 - - - | 289 | | | | 290 ---------------------------------------- | 291 UnionBaseDataset | 292 | 293 SourceDataset 294 - 295 | 296 MappableDataset 297 298 DatasetOperation: MapDataset(UnionBaseDataset) 299 BatchDataset(UnionBaseDataset) 300 PaddedBatchDataset(UnionBaseDataset) 301 BucketBatchByLengthDataset(UnionBaseDataset) 302 ShuffleDataset(UnionBaseDataset) 303 FilterDataset(UnionBaseDataset) 304 RepeatDataset(UnionBaseDataset) 305 SkipDataset(UnionBaseDataset) 306 TakeDataset(UnionBaseDataset) 307 ZipDataset(UnionBaseDataset) 308 ConcatDataset(UnionBaseDataset) 309 RenameDataset(UnionBaseDataset) 310 ProjectDataset(UnionBaseDataset) 311 SyncWaitDataset(UnionBaseDataset) 312 313 Impl Dataset - vision: ImageFolderDataset(MappableDataset, VisionBaseDataset) 314 USPSDataset(SourceDataset, VisionBaseDataset) 315 Impl Dataset - text: TextFileDataset(SourceDataset, TextBaseDataset) 316 YahooAnswersDataset(SourceDataset, TextBaseDataset) 317 Impl Dataset - audio: LJSpeechDataset(MappableDataset, AudioBaseDataset) 318 TedliumDataset(MappableDataset, AudioBaseDataset) 319 Impl Dataset - standard: MindDataset(MappableDataset, UnionBaseDataset) 320 TFRecordDataset(SourceDataset, UnionBaseDataset) 321 Impl Dataset - user defined: GeneratorDataset(MappableDataset, UnionBaseDataset) 322 NumpySlicesDataset(GeneratorDataset) 323 324 Args: 325 num_parallel_workers (int, optional): Number of workers to process the dataset in parallel. 326 Default: ``None``. 327 """ 328 329 def __init__(self, children=None, num_parallel_workers=None, cache=None): 330 # Note: children and parent are internal variables, not recommended for external using. 331 self.children = replace_none(children, []) 332 if isinstance(self.children, tuple): 333 self.children = list(self.children) 334 if not isinstance(self.children, list): 335 self.children = [self.children] 336 337 self.parent = [] 338 for child in self.children: 339 child.parent.append(weakref.ref(self)) 340 self.num_parallel_workers = num_parallel_workers 341 self.cache = cache 342 343 self._device_iter = 0 344 self._input_indexs = () 345 self.saved_output_types = None 346 self.saved_output_shapes = None 347 self.estimated_output_shapes = None 348 self.runtime_context = None 349 self._col_names = None 350 self.dataset_size = None 351 self._batch_size = None 352 self._num_classes = None 353 self._repeat_count = None 354 self._class_indexing = None 355 self._sync = False 356 self._global_step = None 357 358 @staticmethod 359 def _get_operator_id(dataset): 360 """ 361 Internal method to iterate the tree and obtain op_id of each operation. 362 363 Returns: 364 Dataset, the root dataset of the tree. 365 """ 366 op_name = dict() 367 generator_process = dict() 368 op_name[str(dataset)] = 0 369 op_id = 1 370 371 def process_name(datasets, operator_id): 372 if not datasets: 373 return 0 374 temp = [] 375 for item in datasets: 376 for d in item.children: 377 temp.append(d) 378 op_name[str(d)] = operator_id 379 380 from mindspore.dataset.engine.datasets_user_defined import GeneratorDataset 381 if isinstance(d, GeneratorDataset) and d.sample_fn and d.sample_fn.pids: 382 generator_process[operator_id] = [d.num_parallel_workers, set(d.sample_fn.pids)] 383 384 operator_id = operator_id + 1 385 return process_name(temp, operator_id) 386 387 process_name([dataset], op_id) 388 if generator_process: 389 global _OP_PROCESS 390 _OP_PROCESS.update(generator_process) 391 return op_name 392 393 def create_ir_tree(self, getter_mode=False): 394 """ 395 Internal method to build an IR tree. 396 397 Args: 398 getter_mode (bool, optional): Whether to build IR tree in pull mode. Default: ``False``. 399 400 Returns: 401 Union[DatasetNode, Dataset], the root node of the IR tree and the root dataset of the IR tree. 402 """ 403 parent = self.parent 404 self.parent = [] 405 dataset = copy.deepcopy(self) 406 global _OP_NAME 407 _OP_NAME = Dataset._get_operator_id(dataset) 408 ir_tree = dataset.parse_tree(getter_mode) 409 self.parent = parent 410 _init_device_info() 411 return ir_tree, dataset 412 413 def parse_tree(self, getter_mode=False): 414 """ 415 Internal method to parse the API tree into an IR tree. 416 417 Args: 418 getter_mode (bool, optional): Whether to build IR tree in pull mode. Default: ``False``. 419 420 Returns: 421 DatasetNode, the root node of the IR tree. 422 """ 423 if len(self.parent) > 1: 424 raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)") 425 ir_children = [d.parse_tree(getter_mode) for d in self.children] 426 # Bootstrap can only be performed on a copy of the original dataset node. 427 # Bootstrap on original dataset node will make all iterators share the same process pool 428 self.pre_parse(getter_mode) 429 self.iterator_bootstrap() 430 ir_node = self.parse(ir_children) 431 ir_node = self.post_parse(ir_node) 432 return ir_node 433 434 def __safe_deepcopy__(self, memodict, exclude=()): 435 if id(self) in memodict: 436 return memodict[id(self)] 437 cls = self.__class__ 438 new_op = cls.__new__(cls) 439 memodict[id(self)] = new_op 440 for arg, value in self.__dict__.items(): 441 if arg in exclude: 442 setattr(new_op, arg, value) 443 else: 444 try: 445 setattr(new_op, arg, copy.deepcopy(value, memodict)) 446 except TypeError: 447 setattr(new_op, arg, value) 448 return new_op 449 450 @staticmethod 451 def _noop_mode(): 452 if _is_role_sched(): 453 return True 454 return False 455 456 def iterator_bootstrap(self): 457 pass 458 459 def __add__(self, datasets): 460 return self.concat(datasets) 461 462 def to_json(self, filename=""): 463 """ 464 Serialize a pipeline into JSON string and dump into file if filename is provided. 465 466 Args: 467 filename (str): filename of JSON file to be saved as. Default: ``""``. 468 469 Returns: 470 str, JSON string of the pipeline. 471 472 Examples: 473 >>> import mindspore.dataset as ds 474 >>> mnist_dataset_dir = "/path/to/mnist_dataset_directory" 475 >>> dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir) 476 >>> dataset_json = dataset.to_json("/path/to/mnist_dataset_pipeline.json") 477 """ 478 ir_tree, _ = self.create_ir_tree() 479 return json.loads(ir_tree.to_json(filename)) 480 481 @check_bucket_batch_by_length 482 def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes, element_length_function=None, 483 pad_info=None, pad_to_bucket_boundary=False, drop_remainder=False): 484 """ 485 Bucket elements according to their lengths. Each bucket will be padded and batched when 486 they are full. 487 488 A length function is called on each row in the dataset. The row is then 489 bucketed based on its length and bucket boundaries. When a bucket reaches its 490 corresponding size specified in bucket_batch_sizes, the entire bucket will be 491 padded according to pad_info, and then form a batch. 492 493 Refer to the following figure for the execution process: 494 495 .. image:: bucket_batch_by_length_en.png 496 497 Args: 498 column_names (list[str]): Columns passed to element_length_function. 499 bucket_boundaries (list[int]): A list consisting of the upper boundaries 500 of the buckets. Must be strictly increasing. If there are n boundaries, 501 n+1 buckets are created: One bucket for [0, bucket_boundaries[0]), one 502 bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each 503 0<i<n-1, and the last bucket for [bucket_boundaries[n-1], inf). 504 bucket_batch_sizes (list[int]): A list consisting of the batch sizes for 505 each bucket. Must contain len(bucket_boundaries)+1 elements. 506 element_length_function (Callable, optional): A function that takes in 507 M arguments where M = len(column_names) and returns an integer. If no value 508 provided, parameter M the len(column_names) must be 1, and the size of the first 509 dimension of that column will be taken as the length. Default: ``None``. 510 pad_info (dict, optional): The information about how to batch each column. The key 511 corresponds to the column name, and the value must be a tuple of 2 elements. 512 The first element corresponds to the shape to pad to, and the second 513 element corresponds to the value to pad with. If a column is not 514 specified, then that column will be padded to the longest in the current 515 batch, and 0 will be used as the padding value. Any None dimensions will 516 be padded to the longest in the current batch, unless if 517 `pad_to_bucket_boundary` is ``True``. If no padding is wanted, set `pad_info` 518 to ``None``. Default: ``None``. 519 pad_to_bucket_boundary (bool, optional): If ``True``, will pad each None 520 dimension in `pad_info` to the bucket_boundary minus 1. If there are any 521 elements that fall into the last bucket, an error will occur. 522 Default: ``False``. 523 drop_remainder (bool, optional): If ``True``, will drop the last batch for each 524 bucket if it is not a full batch. Default: ``False``. 525 526 Returns: 527 Dataset, a new dataset with the above operation applied. 528 529 Examples: 530 >>> # Create a dataset where certain counts rows are combined into a batch 531 >>> # and drops the last incomplete batch if there is one. 532 >>> import mindspore.dataset as ds 533 >>> import numpy as np 534 >>> def generate_2_columns(n): 535 ... for i in range(n): 536 ... yield (np.array([i]), np.array([j for j in range(i + 1)])) 537 >>> 538 >>> column_names = ["col1", "col2"] 539 >>> dataset = ds.GeneratorDataset(generate_2_columns(8), column_names) 540 >>> bucket_boundaries = [5, 10] 541 >>> bucket_batch_sizes = [2, 1, 1] 542 >>> element_length_function = (lambda col1, col2: max(len(col1), len(col2))) 543 >>> # Will pad col2 to shape [bucket_boundaries[i]] where i is the 544 >>> # index of the bucket that is currently being batched. 545 >>> pad_info = {"col2": ([None], -1)} 546 >>> pad_to_bucket_boundary = True 547 >>> dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, 548 ... bucket_batch_sizes, 549 ... element_length_function, pad_info, 550 ... pad_to_bucket_boundary) 551 """ 552 return BucketBatchByLengthDataset(self, column_names, bucket_boundaries, bucket_batch_sizes, 553 element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder) 554 555 @check_batch 556 def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, **kwargs): 557 """ 558 Combine batch_size number of consecutive rows into batch which apply per_batch_map to the samples first. 559 560 For any column, all the elements within that column must have the same shape. 561 562 Refer to the following figure for the execution process: 563 564 .. image:: batch_en.png 565 566 Note: 567 The order of using repeat and batch reflects the number of batches and per_batch_map. 568 It is recommended that the repeat operation applied after the batch operation finished. 569 570 Args: 571 batch_size (Union[int, Callable]): The number of rows each batch is created with. An 572 int or callable object which takes exactly 1 parameter, BatchInfo. 573 drop_remainder (bool, optional): Determines whether or not to drop the last block 574 whose data row number is less than batch size. Default: ``False`` . If ``True`` , 575 and if there are less than `batch_size` rows available to make the last batch, 576 then those rows will be dropped and not propagated to the child node. 577 num_parallel_workers (int, optional): Number of workers(threads) to process the dataset in parallel. 578 Default: ``None`` . 579 **kwargs: 580 581 - per_batch_map (Callable[[List[numpy.ndarray], ..., List[numpy.ndarray], BatchInfo], \ 582 (List[numpy.ndarray], ..., List[numpy.ndarray])], optional): Per batch map callable. 583 Default: ``None``. 584 A callable which takes (List[numpy.ndarray], ..., List[numpy.ndarray], BatchInfo) as input parameters. 585 Each list[numpy.ndarray] represents a batch of numpy.ndarray on a given column. The number of lists 586 should match with the number of entries in input_columns. The last parameter of the callable should 587 always be a BatchInfo object. Per_batch_map should return 588 (list[numpy.ndarray], list[numpy.ndarray], ...). The length of each list in output should be the same 589 as the input. output_columns is required if the number of output lists is different from input. 590 591 - input_columns (Union[str, list[str]], optional): List of names of the input columns. The size of 592 the list should match with signature of `per_batch_map` callable. Default: ``None`` . 593 594 - output_columns (Union[str, list[str]], optional): List of names assigned to the columns 595 outputted by the last operation. This parameter is mandatory if len(input_columns) != 596 len(output_columns). The size of this list must match the number of output 597 columns of the last operation. Default: ``None`` , output columns will have the same 598 name as the input columns, i.e., the columns will be replaced. 599 600 - python_multiprocessing (bool, optional): Parallelize Python function `per_batch_map` with 601 multi-processing or multi-threading mode, ``True`` means multi-processing, 602 ``False`` means multi-threading If `per_batch_map` is a I/O bound task, use 603 multi-threading mode. If `per_batch_map` is a CPU bound task, it is recommended to use 604 multi-processing mode. Default: ``False`` , use python multi-threading mode. 605 606 - max_rowsize(Union[int, list[int]], optional): Maximum size of row in MB that is used for shared memory 607 allocation to copy data between processes, the total occupied shared memory will increase as 608 ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set 609 to -1, shared memory will be dynamically allocated with the actual size of data. This is only used if 610 ``python_multiprocessing`` is set to True. If it is an int value, it represents 611 ``input_columns`` and ``output_columns`` use this value as the unit to create shared memory. 612 If it is a list, the first element represents the ``input_columns`` use this value as the unit to 613 create shared memory, and the second element represents ``output_columns`` use this value as the unit 614 to create shared memory. Default: 16. 615 616 Returns: 617 Dataset, a new dataset with the above operation applied. 618 619 Examples: 620 >>> # 1) Create a dataset where every 5 rows are combined into a batch 621 >>> # and drops the last incomplete batch if there is one. 622 >>> import mindspore.dataset as ds 623 >>> from PIL import Image 624 >>> 625 >>> cifar10_dataset_dir = "/path/to/cifar10_dataset_directory" 626 >>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_samples=10) 627 >>> dataset = dataset.batch(5, True) 628 >>> 629 >>> # 2) resize image according to its batch number, if it's 5-th batch, resize to (5^2, 5^2) = (25, 25) 630 >>> def np_resize(col, BatchInfo): 631 ... output = col.copy() 632 ... s = (BatchInfo.get_batch_num() + 1) ** 2 633 ... index = 0 634 ... for c in col: 635 ... img = Image.fromarray(c.astype('uint8')).convert('RGB') 636 ... img = img.resize((s, s)) 637 ... output[index] = np.array(img) 638 ... index += 1 639 ... return (output,) 640 >>> dataset = dataset.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize) 641 >>> 642 >>> # 3) Create a dataset where its batch size is dynamic 643 >>> # Define a callable batch size function and let batch size increase 1 each time. 644 >>> def add_one(BatchInfo): 645 ... return BatchInfo.get_batch_num() + 1 646 >>> dataset = dataset.batch(batch_size=add_one, drop_remainder=True) 647 """ 648 return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, **kwargs) 649 650 @check_padded_batch 651 def padded_batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, pad_info=None): 652 """ 653 Combine batch_size number of consecutive rows into batch which apply pad_info to the samples first. 654 655 Refer to the following figure for the execution process: 656 657 .. image:: padded_batch_en.png 658 659 Note: 660 The order of using repeat and padded_batch reflects the number of batches. 661 It is recommended that the repeat operation applied after the padded_batch operation finished. 662 663 Args: 664 batch_size (Union[int, Callable]): The number of rows each batch is created with. An 665 int or callable object which takes exactly 1 parameter, BatchInfo. 666 drop_remainder (bool, optional): Determines whether or not to drop the last block 667 whose data row number is less than batch size. Default: ``False``. If ``True``, and if there 668 are less than batch_size rows available to make the last batch, then those rows will 669 be dropped and not propagated to the child node. 670 num_parallel_workers (int, optional): Number of workers(threads) to process the dataset in parallel. 671 Default: ``None``. 672 pad_info (dict, optional): The pad information about how to batch each column. The key 673 corresponds to the column name, and the value must be a tuple of 2 elements. 674 The first element corresponds to the shape to pad to, and the second 675 element corresponds to the value to pad with. If a column is not 676 specified, then that column will be padded to the longest in the current 677 batch, and 0 will be used as the padding value. If ``pad_info={"col1": ([224, 224], 0)}``, 678 expand the data column named ``col1`` to shape (224, 224), and fill in the missing values with 0. 679 If ``pad_info={}``, all samples in the batch will be filled to the shape with the largest sample 680 in the current batch. If ``pad_info={"col1": (None, 100)}``, all samples in the batch will be filled 681 to the shape with the largest sample in the current batch, and fill in the missing values with 100. 682 If no padding is wanted, set `pad_info` to ``None``. Default: ``None``. 683 684 Returns: 685 Dataset, a new dataset with the above operation applied. 686 687 Examples: 688 >>> # 1) Pad every sample to the largest sample's shape and batch the samples 689 >>> import mindspore.dataset as ds 690 >>> dataset = ds.NumpySlicesDataset([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]], "column1") 691 >>> dataset = dataset.padded_batch(2, True, pad_info={}) 692 >>> 693 >>> # 2) Create a dataset where every 3 rows are combined into a batch 694 >>> # and drops the last incomplete batch if there is one. 695 >>> dataset = ds.NumpySlicesDataset([i for i in range(10)], "column1") 696 >>> dataset = dataset.padded_batch(3, True) 697 >>> 698 >>> # 3) Create a dataset where its batch size is dynamic 699 >>> # Define a callable batch size function and let batch size increase 1 each time. 700 >>> def add_one(BatchInfo): 701 ... return BatchInfo.get_batch_num() + 1 702 >>> dataset = dataset.padded_batch(batch_size=add_one, drop_remainder=True) 703 """ 704 return PaddedBatchDataset(self, batch_size, drop_remainder, num_parallel_workers, pad_info) 705 706 @check_sync_wait 707 def sync_wait(self, condition_name, num_batch=1, callback=None): 708 """ 709 Add a blocking condition to the input Dataset and a synchronize action will be applied. 710 711 Args: 712 condition_name (str): The condition name that is used to toggle sending next row. 713 num_batch (int): the number of batches without blocking at the start of each epoch. 714 Default: ``1``. 715 callback (function): The callback function that will be invoked when sync_update is called. 716 Default: ``None``. 717 718 Returns: 719 Dataset, a new dataset with the above operation applied. 720 721 Raises: 722 RuntimeError: If condition name already exists. 723 724 Examples: 725 >>> import mindspore.dataset as ds 726 >>> import numpy as np 727 >>> def gen(): 728 ... for i in range(100): 729 ... yield (np.array(i),) 730 >>> 731 >>> class Augment: 732 ... def __init__(self, loss): 733 ... self.loss = loss 734 ... 735 ... def preprocess(self, input_): 736 ... return input_ 737 ... 738 ... def update(self, data): 739 ... self.loss = data["loss"] 740 >>> 741 >>> batch_size = 4 742 >>> dataset = ds.GeneratorDataset(gen, column_names=["input"]) 743 >>> 744 >>> aug = Augment(0) 745 >>> dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) 746 >>> dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"]) 747 >>> dataset = dataset.batch(batch_size) 748 >>> count = 0 749 >>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): 750 ... assert data["input"][0] == count 751 ... count += batch_size 752 ... data = {"loss": count} 753 ... dataset.sync_update(condition_name="policy", data=data) 754 """ 755 return SyncWaitDataset(self, condition_name, num_batch, callback) 756 757 @check_shuffle 758 def shuffle(self, buffer_size): 759 """ 760 Shuffle the dataset by creating a cache with the size of `buffer_size` . 761 762 1. Make a shuffle buffer that contains the first `buffer_size` rows. 763 2. Randomly select an element from the shuffle buffer to be the next row 764 propagated to the child node. 765 3. Get the next row (if any) from the parent node and put it in the shuffle buffer. 766 4. Repeat steps 2 and 3 until there are no more rows left in the shuffle buffer. 767 768 A random seed can be provided to be used on the first epoch via `dataset.config.set_seed` . In every subsequent 769 epoch, the seed is changed to a new one, randomly generated value. 770 771 Args: 772 buffer_size (int): The size of the buffer (must be larger than 1) for 773 shuffling. Setting `buffer_size` equal to the number of rows in the entire 774 dataset will result in a global shuffle. 775 776 Returns: 777 Dataset, a new dataset with the above operation applied. 778 779 Raises: 780 RuntimeError: If exist sync operations before shuffle. 781 782 Examples: 783 >>> import mindspore.dataset as ds 784 >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1") 785 >>> 786 >>> # Optionally set the seed for fixed randomness 787 >>> ds.config.set_seed(58) 788 >>> 789 >>> # Create a shuffled dataset using a shuffle buffer of size 4 790 >>> dataset = dataset.shuffle(4) 791 """ 792 return ShuffleDataset(self, buffer_size) 793 794 def flat_map(self, func): 795 """ 796 Map `func` to each row in dataset and flatten the result. 797 798 Args: 799 func (function): A function that must take one `numpy.ndarray` as an argument and 800 return a `Dataset` . 801 802 Returns: 803 Dataset, a new dataset with the above operation applied. 804 805 Examples: 806 >>> import mindspore.dataset as ds 807 >>> # 1) flat_map on one column dataset 808 >>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]], shuffle=False) 809 >>> 810 >>> def repeat(array): 811 ... # create a NumpySlicesDataset with the array 812 ... data = ds.NumpySlicesDataset(array, shuffle=False) 813 ... # repeat the dataset twice 814 ... data = data.repeat(2) 815 ... return data 816 >>> 817 >>> dataset = dataset.flat_map(repeat) 818 >>> # [0, 1, 0, 1, 2, 3, 2, 3] 819 >>> 820 >>> # 2) flat_map on multi column dataset 821 >>> dataset = ds.NumpySlicesDataset(([[0, 1], [2, 3]], [[0, -1], [-2, -3]]), shuffle=False) 822 >>> 823 >>> def plus_and_minus(col1, col2): 824 ... # apply different methods on columns 825 ... data = ds.NumpySlicesDataset((col1 + 1, col2 - 1), shuffle=False) 826 ... return data 827 >>> 828 >>> dataset = dataset.flat_map(plus_and_minus) 829 >>> # ([1, 2, 3, 4], [-1, -2, -3, -4]) 830 831 Raises: 832 TypeError: If `func` is not a function. 833 TypeError: If `func` doesn't return a Dataset. 834 """ 835 dataset = None 836 if not hasattr(func, '__call__'): 837 logger.critical("func must be a function.") 838 raise TypeError("func must be a function.") 839 840 for row_data in self.create_tuple_iterator(num_epochs=1, output_numpy=True): 841 if dataset is None: 842 dataset = func(*row_data) 843 else: 844 dataset += func(*row_data) 845 846 if not isinstance(dataset, Dataset): 847 logger.critical("flat_map must return a Dataset object.") 848 raise TypeError("flat_map must return a Dataset object.") 849 return dataset 850 851 @check_map 852 def map(self, operations, input_columns=None, output_columns=None, column_order=None, 853 num_parallel_workers=None, **kwargs): 854 """ 855 Apply each operation in operations to this dataset. 856 857 Each operation will be passed one or more columns from the dataset as input, and one or 858 more columns will be outputted. The first operation will be passed the columns specified 859 in input_columns as input. If there is more than one operation in operations, the outputted 860 columns of the previous operation are used as the input columns for the next operation. 861 862 The columns outputted by the very last operation will be assigned names specified by 863 `output_columns` , and if not specified, the column name of output column is same as that of `input_columns` . 864 865 - If you use transformations ( 866 `vision transform <https://mindspore.cn/docs/en/master/api_python/mindspore.\ 867 dataset.transforms.html#module-mindspore.dataset.vision>`_ , 868 `nlp transform <https://mindspore.cn/docs/en/master/api_python/mindspore.\ 869 dataset.transforms.html#module-mindspore.dataset.text>`_ , 870 `audio transform <https://mindspore.cn/docs/en/master/api_python/mindspore.\ 871 dataset.transforms.html#module-mindspore.dataset.audio>`_ ) 872 provided by mindspore dataset, please use the following parameters: 873 874 .. image:: map_parameter_en.png 875 876 - If you use user-defined transform as PyFunc (Python Func), please use the following parameters: 877 878 .. image:: map_parameter_pyfunc_en.png 879 880 Args: 881 operations (Union[list[TensorOperation], list[functions]]): List of operations to be 882 applied on the dataset. Operations are applied in the order they appear in this list. 883 input_columns (Union[str, list[str]], optional): List of the names of the columns that will be passed to 884 the first operation as input. The size of this list must match the number of 885 input columns expected by the first operation. Default: ``None``, the first 886 operation will be passed however many columns that are required, starting from 887 the first column. 888 output_columns (Union[str, list[str]], optional): List of names assigned to the columns outputted by 889 the last operation. This parameter is mandatory if len(input_columns) != 890 len(output_columns). The size of this list must match the number of output 891 columns of the last operation. Default: ``None``, output columns will have the same 892 name as the input columns, i.e., the columns will be replaced. 893 num_parallel_workers (int, optional): Number of threads used to process the dataset in 894 parallel. Default: ``None``, the value from the configuration will be used. 895 **kwargs: 896 897 - python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. 898 This option could be beneficial if the Python operation is computational heavy. Default: ``False``. 899 900 - max_rowsize (Union[int, list[int]], optional): Maximum size of row in MB that is used for shared 901 memory allocation to copy data between processes, the total occupied shared memory will increase as 902 ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set 903 to -1, shared memory will be dynamically allocated with the actual size of data. This is only used if 904 ``python_multiprocessing`` is set to True. If it is an int value, it represents 905 ``input_columns`` and ``output_columns`` use this value as the unit to create shared memory. 906 If it is a list, the first element represents the ``input_columns`` use this value as the unit to 907 create shared memory, and the second element represents ``output_columns`` use this value as the unit 908 to create shared memory. Default: 16. 909 910 - cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 911 Default: ``None``, which means no cache is used. 912 913 - callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called. 914 Default: ``None``. 915 916 - offload (bool, optional): Flag to indicate whether offload is used. Default: ``None``. 917 918 Note: 919 - Input `operations` accepts TensorOperations defined in mindspore.dataset part, plus user-defined 920 Python functions (PyFuncs). 921 - Do not add network computing operators from mindspore.nn and mindspore.ops or others into this 922 `operations` . 923 924 Returns: 925 Dataset, a new dataset with the above operation applied. 926 927 Examples: 928 >>> import mindspore.dataset as ds 929 >>> import mindspore.dataset.vision as vision 930 >>> # dataset is an instance of Dataset which has 2 columns, "image" and "label". 931 >>> # image is of type bytes type which can be decoded to RGB 932 >>> # label is of type int32 933 >>> cifar10_dataset_dir = "/path/to/cifar10_dataset_directory" 934 >>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir) 935 >>> 936 >>> # Define two operations, where each operation accepts 1 input column and outputs 1 column. 937 >>> decode_op = vision.Decode(to_pil=False) 938 >>> random_jitter_op = vision.RandomColorAdjust(brightness=(0.8, 0.8), contrast=(1, 1), 939 ... saturation=(1, 1), hue=(0, 0)) 940 >>> 941 >>> # 1) Simple map example. 942 >>> 943 >>> # Apply decode_op on column "image". 944 >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"]) 945 >>> 946 >>> # Decode and rename column "image" to "decoded_image". 947 >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"], output_columns=["decoded_image"]) 948 >>> 949 >>> # A simple example for user defined python function transform. 950 >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"]) 951 >>> dataset = dataset.map(operations=[(lambda x: x - 1)], input_columns=["data"]) 952 >>> 953 >>> # 2) Map example with more than one operation. 954 >>> 955 >>> # Create a dataset where the images are decoded, then randomly color jittered. 956 >>> # decode_op takes column "image" as input and outputs one column. The column 957 >>> # outputted by decode_op is passed as input to random_jitter_op. 958 >>> # random_jitter_op will output one column. Column "image" will be replaced by 959 >>> # the column outputted by random_jitter_op (the very last operation). All other 960 >>> # columns are unchanged. 961 >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"]) 962 >>> 963 >>> # Rename the column outputted by random_jitter_op to "image_mapped". 964 >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"], 965 ... output_columns=["image_mapped"]) 966 >>> 967 >>> # Map with multiple operations using pyfunc and rename column's name 968 >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"]) 969 >>> dataset = dataset.map(operations=[(lambda x: x * x), (lambda x: x - 1)], input_columns=["data"], 970 ... output_columns=["data_mapped"]) 971 >>> 972 >>> # 3) Example where number of input columns is not equal to number of output columns. 973 >>> 974 >>> # operations[0] is a lambda that takes 2 columns as input and outputs 3 columns. 975 >>> # operations[1] is a lambda that takes 3 columns as input and outputs 1 column. 976 >>> # operations[2] is a lambda that takes 1 column as input and outputs 4 columns. 977 >>> # 978 >>> # Note: The number of output columns of operation[i] must equal the number of 979 >>> # input columns of operation[i+1]. Otherwise, this map call will also result 980 >>> # in an error. 981 >>> operations = [(lambda x, y: (x, x + y, x + y + 1)), 982 ... (lambda x, y, z: x * y * z), 983 ... (lambda x: (x % 2, x % 3, x % 5, x % 7))] 984 >>> dataset = ds.NumpySlicesDataset(data=([[0, 1, 2]], [[3, 4, 5]]), column_names=["x", "y"]) 985 >>> dataset = dataset.map(operations, input_columns=["x", "y"], 986 ... output_columns=["mod2", "mod3", "mod5", "mod7"]) 987 """ 988 if hasattr(self, 'operator_mixed') and getattr(self, 'operator_mixed') is True: 989 num_parallel_workers = 1 990 logger.warning( 991 "Input 'operations' of 'map' includes network computing operators like in mindspore.nn, mindspore.ops, " 992 "mindspore.numpy module and etc, which do not support multi-thread compiling, recommend to replace it " 993 "with python implemented operator like numpy etc. Here decrease 'num_parallel_workers' into 1.") 994 995 return MapDataset(self, operations, input_columns, output_columns, num_parallel_workers, **kwargs) 996 997 @check_filter 998 def filter(self, predicate, input_columns=None, num_parallel_workers=None): 999 """ 1000 Filter dataset by prediction. 1001 1002 Args: 1003 predicate (callable): Python callable which returns a boolean value. If False then filter the element. 1004 input_columns (Union[str, list[str]], optional): List of names of the input columns. If not provided 1005 or provided with ``None``, the predicate will be applied on all columns in the dataset. 1006 Default: ``None``. 1007 num_parallel_workers (int, optional): Number of workers to process the dataset 1008 in parallel. Default: ``None``. 1009 1010 Returns: 1011 Dataset, a new dataset with the above operation applied. 1012 1013 Examples: 1014 >>> # generator data(0 ~ 19) 1015 >>> # filter the data that greater than or equal to 11 1016 >>> import mindspore.dataset as ds 1017 >>> dataset = ds.GeneratorDataset([i for i in range(20)], "data") 1018 >>> dataset = dataset.filter(predicate=lambda data: data < 11, input_columns = ["data"]) 1019 """ 1020 return FilterDataset(self, predicate, input_columns, num_parallel_workers) 1021 1022 @check_repeat 1023 def repeat(self, count=None): 1024 """ 1025 Repeat this dataset `count` times. Repeat infinitely if the `count` is ``None`` or ``-1``. 1026 1027 Note: 1028 The order of using repeat and batch reflects the number of batches. It is recommended that 1029 the repeat operation is used after the batch operation. 1030 1031 Args: 1032 count (int): Number of times the dataset is going to be repeated. Default: ``None``. 1033 1034 Returns: 1035 Dataset, a new dataset with the above operation applied. 1036 1037 Examples: 1038 >>> import mindspore.dataset as ds 1039 >>> 1040 >>> # Create a dataset with 10 elements 1041 >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1") 1042 >>> ori_size = dataset.get_dataset_size() 1043 >>> 1044 >>> # Repeat the dataset 50 times. 1045 >>> dataset = dataset.repeat(50) 1046 >>> repeated_size = dataset.get_dataset_size() 1047 >>> print("ori_size", ori_size, ", repeated_size", repeated_size) 1048 ori_size 10 , repeated_size 500 1049 >>> 1050 >>> # Since the original dataset size is less than batch_size, thus no data is returned 1051 >>> dataset1 = ds.GeneratorDataset([i for i in range(10)], "column1") 1052 >>> dataset1 = dataset1.batch(batch_size=20, drop_remainder=True) 1053 >>> dataset1 = dataset1.repeat(6) 1054 >>> 1055 >>> # Repeat the original dataset to 60 elements, thus 3 batches are returned 1056 >>> dataset2 = ds.GeneratorDataset([i for i in range(10)], "column1") 1057 >>> dataset2 = dataset2.repeat(6) 1058 >>> dataset2 = dataset2.batch(batch_size=20, drop_remainder=True) 1059 >>> print("dataset1 size", dataset1.get_dataset_size(), ", dataset2 size", dataset2.get_dataset_size()) 1060 dataset1 size 0 , dataset2 size 3 1061 """ 1062 return RepeatDataset(self, count) 1063 1064 @check_skip 1065 def skip(self, count): 1066 """ 1067 Skip the first N elements of this dataset. 1068 1069 Args: 1070 count (int): Number of elements in the dataset to be skipped. 1071 1072 Returns: 1073 Dataset, a new dataset with the above operation applied. 1074 1075 Examples: 1076 >>> import mindspore.dataset as ds 1077 >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1") 1078 >>> # Skip first 3 elements of dataset and retain 7 elements. 1079 >>> dataset = dataset.skip(3) 1080 """ 1081 return SkipDataset(self, count) 1082 1083 @check_take 1084 def take(self, count=-1): 1085 """ 1086 Take the first specified number of samples from the dataset. 1087 1088 Args: 1089 count (int, optional): The desired number of samples to take. If the value exceeds 1090 the total number of samples in the dataset, all data will be returned. 1091 Default: ``-1`` , will return all data. 1092 1093 Note: 1094 When there are operations that will change the number of samples of the dataset in 1095 the data pipeline, the location of the `take` operation can change its effect. 1096 For example, `batch` operation will combine the successive samples of the specified 1097 `batch_size` into 1 sample, so `.batch(batch_size).take(1)` will be equivalent to 1098 `.take(batch_size).batch(batch_size)`. 1099 1100 Returns: 1101 Dataset, a new dataset with the above operation applied. 1102 1103 Examples: 1104 >>> import mindspore.dataset as ds 1105 >>> mnist_dataset_dir = "/path/to/mnist_dataset_directory" 1106 >>> dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir) 1107 >>> # Take 50 samples from MNIST dataset. 1108 >>> dataset = dataset.take(50) 1109 """ 1110 return TakeDataset(self, count) 1111 1112 def _get_absolute_split_sizes(self, sizes): 1113 """ 1114 Internal method called by split to calculate absolute split sizes and to 1115 do some error checking after calculating absolute split sizes. 1116 1117 Returns: 1118 int, absolute split sizes of the dataset. 1119 """ 1120 # Call get_dataset_size here and check input here because 1121 # don't want to call this once in check_split and another time in 1122 # here again 1123 dataset_size = self.get_dataset_size() 1124 1125 if dataset_size is None or dataset_size <= 0: 1126 raise RuntimeError("dataset_size is unknown, unable to split.") 1127 1128 if not isinstance(sizes, list): 1129 raise RuntimeError("sizes must be a list.") 1130 1131 all_int = all(isinstance(item, int) for item in sizes) 1132 if all_int: 1133 sizes_sum = sum(sizes) 1134 if sizes_sum != dataset_size: 1135 raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}." 1136 .format(sizes_sum, dataset_size)) 1137 return sizes 1138 1139 absolute_sizes = [] 1140 for item in sizes: 1141 absolute_size = int(round(item * dataset_size)) 1142 if absolute_size == 0: 1143 raise RuntimeError("Split percentage {} is too small.".format(item)) 1144 absolute_sizes.append(absolute_size) 1145 1146 absolute_sizes_sum = sum(absolute_sizes) 1147 1148 # if we still need more rows, give them to the first split. 1149 # if we have too many rows, remove the extras from the first split that has 1150 # enough rows. 1151 size_difference = int(dataset_size - absolute_sizes_sum) 1152 if size_difference > 0: 1153 absolute_sizes[0] += size_difference 1154 else: 1155 for i, _ in enumerate(absolute_sizes): 1156 if absolute_sizes[i] + size_difference > 0: 1157 absolute_sizes[i] += size_difference 1158 break 1159 1160 if sum(absolute_sizes) != dataset_size: 1161 raise RuntimeError("Sum of calculated split sizes {} is not equal to dataset size {}." 1162 .format(absolute_sizes_sum, dataset_size)) 1163 1164 return absolute_sizes 1165 1166 @check_split 1167 def split(self, sizes, randomize=True): 1168 """ 1169 Split the dataset into smaller, non-overlapping datasets. 1170 1171 Args: 1172 sizes (Union[list[int], list[float]]): If a list of integers [s1, s2, …, sn] is 1173 provided, the dataset will be split into n datasets of size s1, size s2, …, size sn 1174 respectively. If the sum of all input sizes does not equal the original dataset size, an 1175 error will throw. 1176 If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1 1177 and must sum to 1, otherwise an error will throw. The dataset will be split into n 1178 Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the 1179 original dataset. 1180 If after rounding: 1181 1182 - Any size equals 0, an error will occur. 1183 - The sum of split sizes < K, the difference of K - sigma(round(fi * k)) will be added to the first 1184 split. 1185 - The sum of split sizes > K, the difference of sigma(round(fi * K)) - K will be removed from the first 1186 large enough split such that it will have at least 1 row after removing the difference. 1187 1188 randomize (bool, optional): Determines whether or not to split the data randomly. Default: ``True``. 1189 If True, the data will be randomly split. Otherwise, each split will be created with 1190 consecutive rows from the dataset. 1191 1192 Note: 1193 1. Dataset cannot be sharded if split is going to be called. 1194 2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead. 1195 Shuffling the dataset may not be deterministic, which means the data in each split 1196 will be different in each epoch. 1197 1198 Returns: 1199 Tuple[Dataset], a tuple of new datasets split from the original one. 1200 1201 Raises: 1202 RuntimeError: If get_dataset_size returns None or is not supported for this dataset. 1203 RuntimeError: If `sizes` is list of integers and sum of all elements in sizes does not 1204 equal the dataset size. 1205 RuntimeError: If `sizes` is list of float and there is a split with size 0 after calculations. 1206 RuntimeError: If the dataset is sharded prior to calling split. 1207 ValueError: If `sizes` is list of float and not all floats are between 0 and 1, or if the 1208 floats don't sum to 1. 1209 1210 Examples: 1211 >>> # Split the data into train part and test part. 1212 >>> import mindspore.dataset as ds 1213 >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1") 1214 >>> train_dataset, test_dataset = dataset.split([0.9, 0.1]) 1215 """ 1216 if self.is_shuffled(): 1217 logger.warning("Dataset is shuffled before split.") 1218 1219 if self.is_sharded(): 1220 raise RuntimeError("Dataset should not be sharded before split.") 1221 1222 absolute_sizes = self._get_absolute_split_sizes(sizes) 1223 splits = [] 1224 rows_to_skip = 0 1225 for size in absolute_sizes: 1226 ds = copy.deepcopy(self) 1227 if randomize: 1228 # want to shuffle the same way every epoch before split 1229 # in alter_tree, shuffle buffer is minimum 10000, so use 10000 here 1230 ds = ds.shuffle(10000) 1231 ds.reshuffle_each_epoch = False 1232 1233 if rows_to_skip > 0: 1234 ds = ds.skip(rows_to_skip) 1235 1236 ds = ds.take(size) 1237 splits.append(ds) 1238 1239 rows_to_skip += size 1240 1241 return tuple(splits) 1242 1243 @check_zip_dataset 1244 def zip(self, datasets): 1245 """ 1246 Zip the datasets in the sense of input tuple of datasets. Columns in the input datasets must have different 1247 name. 1248 1249 Args: 1250 datasets (Union[Dataset, tuple[Dataset]]): A tuple of datasets or a single class Dataset 1251 to be zipped together with this dataset. 1252 1253 Returns: 1254 Dataset, a new dataset with the above operation applied. 1255 1256 Raises: 1257 TypeError: The parameter is not dataset object or tuple of dataset objects. 1258 1259 Examples: 1260 >>> # Create a dataset which is the combination of dataset_1 and dataset_2 1261 >>> import mindspore.dataset as ds 1262 >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1") 1263 >>> dataset_2 = ds.GeneratorDataset([1, 2, 3], "column2") 1264 >>> dataset = dataset_1.zip(dataset_2) 1265 """ 1266 if isinstance(datasets, tuple): 1267 datasets = (self, *datasets) 1268 elif isinstance(datasets, Dataset): 1269 datasets = (self, datasets) 1270 else: 1271 raise TypeError("Invalid datasets, expected Dataset object or tuple of Dataset, but got %s!" % datasets) 1272 return ZipDataset(datasets) 1273 1274 @check_concat 1275 def concat(self, datasets): 1276 """ 1277 Concatenate the dataset objects in the input list. 1278 Performing "+" operation on dataset objects can achieve the same effect. 1279 1280 For a dataset concatenated by many other dataset objects, it returns the data in the order of 1281 datasets passed in. If you want to change the data order(such as random selection from each dataset 1282 instead of in sequence), apply `use_sampler` method on the concatenated dataset object. 1283 Currently `use_sampler` supports `dataset.DistributedSampler` for sharding selection from each dataset 1284 or `dataset.RandomSampler` for random selection from each dataset, see examples below. 1285 1286 Note: 1287 The column name, and rank and type of the column data must be the same in the input datasets. 1288 1289 Args: 1290 datasets (Union[list, Dataset]): A list of datasets or a single class Dataset 1291 to be concatenated together with this dataset. 1292 1293 Returns: 1294 Dataset, a new dataset with the above operation applied. 1295 1296 Examples: 1297 >>> import mindspore.dataset as ds 1298 >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1", shuffle=False) 1299 >>> dataset_2 = ds.GeneratorDataset([4, 5, 6], "column1", shuffle=False) 1300 >>> 1301 >>> # Create a dataset by concatenating dataset_1 and dataset_2 with "+" operator 1302 >>> dataset = dataset_1 + dataset_2 1303 >>> # Create a dataset by concatenating dataset_1 and dataset_2 with concat operation 1304 >>> dataset = dataset_1.concat(dataset_2) 1305 >>> 1306 >>> # Check the data order of dataset 1307 >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1", shuffle=False) 1308 >>> dataset_2 = ds.GeneratorDataset([4, 5, 6], "column1", shuffle=False) 1309 >>> dataset = dataset_1 + dataset_2 1310 >>> result = list(dataset) 1311 >>> # [[Tensor(shape=[], dtype=Int64, value= 1)], [Tensor(shape=[], dtype=Int64, value= 2)], 1312 >>> # [Tensor(shape=[], dtype=Int64, value= 3)], [Tensor(shape=[], dtype=Int64, value= 4)], 1313 >>> # [Tensor(shape=[], dtype=Int64, value= 5)], [Tensor(shape=[], dtype=Int64, value= 6)]] 1314 >>> 1315 >>> # Change the data order of concatenated dataset with sharding selection 1316 >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1", shuffle=False) 1317 >>> dataset_2 = ds.GeneratorDataset([4, 5, 6], "column1", shuffle=False) 1318 >>> dataset = dataset_1.concat(dataset_2) 1319 >>> dataset.use_sampler(ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False)) 1320 >>> result = list(dataset) 1321 >>> # [[Tensor(shape=[], dtype=Int64, value= 2)], [Tensor(shape=[], dtype=Int64, value= 4)], 1322 >>> # [Tensor(shape=[], dtype=Int64, value= 6)]] 1323 >>> 1324 >>> # Change the data order of concatenated dataset with random selection 1325 >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1", shuffle=False) 1326 >>> dataset_2 = ds.GeneratorDataset([4, 5, 6], "column1", shuffle=False) 1327 >>> dataset = dataset_1.concat(dataset_2) 1328 >>> dataset.use_sampler(ds.RandomSampler()) 1329 >>> result = list(dataset) 1330 >>> # [[Tensor(shape=[], dtype=Int64, value= 1)], [Tensor(shape=[], dtype=Int64, value= 4)], 1331 >>> # [Tensor(shape=[], dtype=Int64, value= 2)], [Tensor(shape=[], dtype=Int64, value= 5)], 1332 >>> # [Tensor(shape=[], dtype=Int64, value= 6)], [Tensor(shape=[], dtype=Int64, value= 3)]] 1333 """ 1334 if isinstance(datasets, Dataset): 1335 datasets = [self] + [datasets] 1336 elif isinstance(datasets, list): 1337 datasets = [self] + datasets 1338 else: 1339 raise TypeError("Invalid datasets, expected Dataset object or list of Dataset, but got %s!" % datasets) 1340 return ConcatDataset(datasets) 1341 1342 @check_rename 1343 def rename(self, input_columns, output_columns): 1344 """ 1345 Rename the columns in input datasets. 1346 1347 Args: 1348 input_columns (Union[str, list[str]]): List of names of the input columns. 1349 output_columns (Union[str, list[str]]): List of names of the output columns. 1350 1351 Returns: 1352 Dataset, a new dataset with the above operation applied. 1353 1354 Examples: 1355 >>> import mindspore.dataset as ds 1356 >>> input_columns = ["input_col1", "input_col2", "input_col3"] 1357 >>> output_columns = ["output_col1", "output_col2", "output_col3"] 1358 >>> 1359 >>> # Create a dataset with 3 columns 1360 >>> dataset = ds.GeneratorDataset([(1, 2, 3), (3, 4, 5), (5, 6, 7)], column_names=input_columns) 1361 >>> 1362 >>> # Rename "input_col1" to "output_col1", "input_col2" to "output_col2", "input_col3" to "output_col3" 1363 >>> dataset = dataset.rename(input_columns=input_columns, output_columns=output_columns) 1364 """ 1365 1366 return RenameDataset(self, input_columns, output_columns) 1367 1368 @check_project 1369 def project(self, columns): 1370 """ 1371 The specified columns will be selected from the dataset and passed into 1372 the pipeline with the order specified. The other columns are discarded. 1373 1374 Args: 1375 columns(Union[str, list[str]]): List of names of the columns to project. 1376 1377 Returns: 1378 Dataset, a new dataset with the above operation applied. 1379 1380 Examples: 1381 >>> import mindspore.dataset as ds 1382 >>> # Create a dataset with 3 columns 1383 >>> input_columns = ["column1", "column2", "column3"] 1384 >>> dataset = ds.GeneratorDataset([(1, 2, 3), (3, 4, 5), (5, 6, 7)], column_names=input_columns) 1385 >>> 1386 >>> columns_to_project = ["column3", "column1", "column2"] 1387 >>> # in that order, regardless of the original order of columns. 1388 >>> dataset = dataset.project(columns=columns_to_project) 1389 """ 1390 1391 return ProjectDataset(self, columns) 1392 1393 def apply(self, apply_func): 1394 """ 1395 Apply a function in this dataset. 1396 1397 Args: 1398 apply_func (function): A function that must take one `Dataset` as an argument and 1399 return a preprocessed `Dataset` . 1400 1401 Returns: 1402 Dataset, a new dataset with the above operation applied. 1403 1404 Examples: 1405 >>> import mindspore.dataset as ds 1406 >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1") 1407 >>> 1408 >>> # Declare an apply_func function which returns a Dataset object 1409 >>> def apply_func(data): 1410 ... data = data.batch(2) 1411 ... return data 1412 >>> 1413 >>> # Use apply to call apply_func 1414 >>> dataset = dataset.apply(apply_func) 1415 1416 Raises: 1417 TypeError: If apply_func is not a function. 1418 TypeError: If apply_func doesn't return a Dataset. 1419 """ 1420 1421 if not hasattr(apply_func, '__call__'): 1422 raise TypeError("apply_func must be a function.") 1423 1424 dataset = apply_func(self) 1425 if not isinstance(dataset, Dataset): 1426 raise TypeError("apply_func must return a dataset.") 1427 return dataset 1428 1429 @check_device_send 1430 def device_que(self, send_epoch_end=True, create_data_info_queue=False, queue_name=""): 1431 """ 1432 Return a transferred Dataset that transfers data through a device. 1433 1434 Args: 1435 send_epoch_end (bool, optional): Whether to send end of sequence to device or not. 1436 Default: ``True``. 1437 create_data_info_queue (bool, optional): Whether to create queue which stores 1438 types and shapes of data or not. Default: ``False``. 1439 queue_name (str, optional): Name of queue which connects dataset processing and model 1440 computing. Default: ``""``. 1441 1442 Note: 1443 If device is Ascend, features of data will be transferred one by one. The limitation 1444 of data transmission per time is 256M. 1445 1446 Returns: 1447 Dataset, a new dataset with the above operation applied. 1448 1449 Examples: 1450 >>> import mindspore.dataset as ds 1451 >>> import time 1452 >>> 1453 >>> data = ds.TFRecordDataset('/path/to/TF_FILES', '/path/to/TF_SCHEMA_FILE', shuffle=ds.Shuffle.FILES) 1454 >>> data = data.device_que() 1455 >>> data.send() 1456 >>> time.sleep(0.1) 1457 >>> data.stop_send() 1458 """ 1459 return TransferDataset(self, send_epoch_end, create_data_info_queue, queue_name) 1460 1461 @check_save 1462 def save(self, file_name, num_files=1, file_type='mindrecord'): 1463 """ 1464 Save the dynamic data processed by the dataset pipeline in common dataset format. 1465 Supported dataset formats: ``'mindrecord'`` only. And you can use 1466 :class:`mindspore.dataset.MindDataset` API to read the saved file(s). 1467 1468 Implicit type casting exists when saving data as ``'mindrecord'`` . The transform table shows how to do 1469 type casting. 1470 1471 .. list-table:: Implicit Type Casting when Saving as `mindrecord` 1472 :widths: 25 25 50 1473 :header-rows: 1 1474 1475 * - Type in `dataset` 1476 - Type in `mindrecord` 1477 - Details 1478 * - bool 1479 - int32 1480 - transform to int32 1481 * - int8 1482 - int32 1483 - 1484 * - uint8 1485 - int32 1486 - 1487 * - int16 1488 - int32 1489 - 1490 * - uint16 1491 - int32 1492 - 1493 * - int32 1494 - int32 1495 - 1496 * - uint32 1497 - int64 1498 - 1499 * - int64 1500 - int64 1501 - 1502 * - uint64 1503 - int64 1504 - Maybe reverse 1505 * - float16 1506 - float32 1507 - 1508 * - float32 1509 - float32 1510 - 1511 * - float64 1512 - float64 1513 - 1514 * - string 1515 - string 1516 - Multi-dimensional string not supported 1517 * - bytes 1518 - bytes 1519 - Multi-dimensional bytes not supported 1520 1521 Note: 1522 1. To save the samples in order, set dataset's `shuffle` to ``False`` and `num_files` to ``1``. 1523 2. Before calling the function, do not use batch operation, repeat operation or data augmentation operations 1524 with random attribute in map operation. 1525 3. When array dimension is variable, one-dimensional arrays or 1526 multi-dimensional arrays with variable dimension 0 are supported. 1527 4. MindRecord does not support multi-dimensional string or multi-dimensional bytes. 1528 1529 Args: 1530 file_name (str): Path to dataset file. 1531 num_files (int, optional): Number of dataset files. Default: ``1`` . 1532 file_type (str, optional): Dataset format. Default: ``'mindrecord'`` . 1533 1534 Examples: 1535 >>> import mindspore.dataset as ds 1536 >>> import numpy as np 1537 >>> 1538 >>> def generator_1d(): 1539 ... for i in range(10): 1540 ... yield (np.array([i]),) 1541 >>> 1542 >>> # apply dataset operations 1543 >>> d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False) 1544 >>> d1.save('/path/to/save_file') 1545 """ 1546 if (_get_enc_key() is not None or _get_hash_mode() is not None) and num_files > 1: 1547 raise RuntimeError("When encode mode or hash check is enabled, " + 1548 "the automatic sharding function is unavailable.") 1549 1550 ir_tree, api_tree = self.create_ir_tree() 1551 1552 runtime_context = cde.PythonRuntimeContext() 1553 runtime_context.Init() 1554 consumer = cde.PythonSaveToDisk(file_name, num_files, file_type) 1555 consumer.Init(ir_tree) 1556 runtime_context.AssignConsumer(consumer) 1557 1558 consumer.Save() 1559 1560 if _get_hash_mode() is not None: 1561 append_hash_to_file(file_name) 1562 append_hash_to_file(file_name + ".db") 1563 1564 if _get_enc_key() is not None: 1565 encrypt(file_name, _get_enc_key(), _get_enc_mode()) 1566 encrypt(file_name + ".db", _get_enc_key(), _get_enc_mode()) 1567 1568 _set_dataset_permissions(file_name, num_files) 1569 del api_tree 1570 1571 @check_tuple_iterator 1572 def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True): 1573 """ 1574 Create an iterator over the dataset that yields samples of type list, whose elements are 1575 the data for each column. 1576 1577 Args: 1578 columns (list[str], optional): Specify the output columns and the order. 1579 Default: ``None``, keep all the output columns and their original order. 1580 num_epochs (int, optional): The number of epochs to iterate over the entire dataset. 1581 Default: ``-1`` , the dataset can be iterated indefinitely. 1582 output_numpy (bool, optional): Whether to keep the output data as NumPy ndarray, or 1583 convert it to Tensor. Default: ``False`` . 1584 do_copy (bool, optional): Whether to copy the data when converting output to Tensor, 1585 or reuse the buffer for better performance, only works when `output_numpy` is ``False`` . 1586 Default: ``True`` . 1587 1588 Returns: 1589 Iterator, a dataset iterator that yields samples of type list. 1590 1591 Examples: 1592 >>> import mindspore.dataset as ds 1593 >>> 1594 >>> dataset = ds.GeneratorDataset([i for i in range(10)], "data") 1595 >>> num_epochs = 3 1596 >>> iterator = dataset.create_tuple_iterator(num_epochs=num_epochs) 1597 >>> for epoch in range(num_epochs): 1598 ... for item in iterator: 1599 ... # output is of type tuple 1600 ... print(type(item)) 1601 ... break 1602 ... break 1603 <class 'list'> 1604 """ 1605 if output_numpy is None: 1606 output_numpy = False 1607 1608 if Dataset._noop_mode(): 1609 return DummyIterator(self, 'tuple', output_numpy) 1610 return TupleIterator(self, columns, num_epochs, output_numpy, do_copy) 1611 1612 @check_dict_iterator 1613 def create_dict_iterator(self, num_epochs=-1, output_numpy=False, do_copy=True): 1614 """ 1615 Create an iterator over the dataset that yields samples of type dict, 1616 while the key is the column name and the value is the data. 1617 1618 Args: 1619 num_epochs (int, optional): The number of epochs to iterate over the entire dataset. 1620 Default: ``-1`` , the dataset can be iterated indefinitely. 1621 output_numpy (bool, optional): Whether to keep the output data as NumPy ndarray, or 1622 convert it to Tensor. Default: ``False`` . 1623 do_copy (bool, optional): Whether to copy the data when converting output to Tensor, 1624 or reuse the buffer for better performance, only works when `output_numpy` is ``False`` . 1625 Default: ``True`` . 1626 1627 Returns: 1628 Iterator, a dataset iterator that yields samples of type dict. 1629 1630 Examples: 1631 >>> import mindspore.dataset as ds 1632 >>> 1633 >>> dataset = ds.GeneratorDataset([i for i in range(10)], "data") 1634 >>> num_epochs = 3 1635 >>> iterator = dataset.create_dict_iterator(num_epochs=num_epochs) 1636 >>> for epoch in range(num_epochs): 1637 ... for item in iterator: 1638 ... # output is of type dict 1639 ... print(type(item)) 1640 ... break 1641 ... break 1642 <class 'dict'> 1643 """ 1644 if output_numpy is None: 1645 output_numpy = False 1646 1647 if Dataset._noop_mode(): 1648 return DummyIterator(self, 'dict', output_numpy) 1649 return DictIterator(self, num_epochs, output_numpy, do_copy) 1650 1651 def __iter__(self): 1652 """Create an iterator over the dataset.""" 1653 return self.create_tuple_iterator(num_epochs=1) 1654 1655 @property 1656 def input_indexs(self): 1657 """ 1658 Get the column index, which represents the corresponding relationship between the data column order 1659 and the network when using the sink mode. 1660 1661 Returns: 1662 int, tuple of the input index information. 1663 1664 Examples: 1665 >>> import mindspore.dataset as ds 1666 >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1") 1667 >>> # set input_indexs 1668 >>> dataset.input_indexs = 10 1669 >>> print(dataset.input_indexs) 1670 10 1671 """ 1672 if self._input_indexs != (): 1673 return self._input_indexs 1674 1675 # find input_indexes of children 1676 children_input_index = [child.input_indexs for child in self.children] 1677 1678 # in case of more than one child, return the first input_indexes 1679 for cix in children_input_index: 1680 if cix != (): 1681 return cix 1682 1683 # if all children's input_indexes are () or the node is a leaf 1684 return self._input_indexs 1685 1686 @input_indexs.setter 1687 def input_indexs(self, value): 1688 self._input_indexs = value 1689 1690 def copy_batch_size(self, value): 1691 self._batch_size = value 1692 1693 def _init_tree_getters(self, getter_mode=True): 1694 """ 1695 Get pipeline information. 1696 1697 Args: 1698 getter_mode (bool, optional): Whether to build IR tree in pull mode. Default: ``True``. 1699 """ 1700 ir_tree, api_tree = self.create_ir_tree(getter_mode) 1701 1702 runtime_context = cde.PythonRuntimeContext() 1703 runtime_context.Init() 1704 getter = cde.TreeGetters() 1705 getter.Init(ir_tree) 1706 runtime_context.AssignConsumer(getter) 1707 return getter, runtime_context, api_tree 1708 1709 def __init_size_getter(self): 1710 """ 1711 Get pipeline information. 1712 """ 1713 ir_tree, api_tree = self.create_ir_tree() 1714 1715 runtime_context = cde.PythonRuntimeContext() 1716 runtime_context.Init() 1717 getter = cde.DatasetSizeGetters() 1718 getter.Init(ir_tree) 1719 runtime_context.AssignConsumer(getter) 1720 return getter, runtime_context, api_tree 1721 1722 def get_col_names(self): 1723 """ 1724 Return the names of the columns in dataset. 1725 1726 Returns: 1727 list, list of column names in the dataset. 1728 1729 Examples: 1730 >>> import mindspore.dataset as ds 1731 >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1") 1732 >>> col_names = dataset.get_col_names() 1733 >>> print(col_names) 1734 ['column1'] 1735 1736 """ 1737 if self._col_names is None: 1738 runtime_getter = self._init_tree_getters() 1739 self._col_names = runtime_getter[0].GetColumnNames() 1740 1741 return self._col_names 1742 1743 @check_output_shape 1744 def output_shapes(self, estimate=False): 1745 """ 1746 Get the shapes of output data. 1747 1748 Args: 1749 estimate (bool): If `estimate` is ``False`` , will return the shapes of first data row. 1750 Otherwise, will iterate the whole dataset and return the estimated shapes of data row, 1751 where dynamic shape is marked as None (used in dynamic data shapes scenario). 1752 Default: ``False`` . 1753 1754 Returns: 1755 list, list of shapes of each column. 1756 1757 Examples: 1758 >>> import mindspore.dataset as ds 1759 >>> import numpy as np 1760 >>> 1761 >>> def generator1(): 1762 ... for i in range(1, 100): 1763 ... yield np.ones((16, 83, 83)), np.array([i]) 1764 >>> 1765 >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"]) 1766 >>> output_shapes = dataset.output_shapes() 1767 >>> print(output_shapes) 1768 [[16, 83, 83], [1]] 1769 """ 1770 # cache single shape 1771 if not estimate and self.saved_output_shapes is not None: 1772 return self.saved_output_shapes 1773 # cache estimate shape 1774 if estimate and self.estimated_output_shapes is not None: 1775 return self.estimated_output_shapes 1776 1777 # We have a hang problem when two-level pipeline with multiprocessing, we need to extend the life cycle 1778 # of runtime_context. We found this hang problem only occur on output_types and output_shapes. 1779 runtime_getter = self._init_tree_getters() 1780 self.runtime_context = runtime_getter[1] 1781 api_tree = runtime_getter[2] 1782 output_shapes = runtime_getter[0].GetOutputShapes(estimate) 1783 del api_tree 1784 # Need to terminate the runtime context to avoid the occasional hang problem for 1785 # Python (with multiprocessing enabled) in sink mode. 1786 self.runtime_context.Terminate() 1787 del self.runtime_context 1788 1789 if estimate: 1790 self.estimated_output_shapes = output_shapes 1791 else: 1792 self.saved_output_shapes = output_shapes 1793 return output_shapes 1794 1795 def output_types(self): 1796 """ 1797 Get the types of output data. 1798 1799 Returns: 1800 list, list of data types. 1801 1802 Examples: 1803 >>> import mindspore.dataset as ds 1804 >>> import numpy as np 1805 >>> 1806 >>> def generator1(): 1807 ... for i in range(1, 100): 1808 ... yield np.ones((16, 83, 83)).astype(np.float32), np.array([i]).astype(np.int32) 1809 >>> 1810 >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"]) 1811 >>> output_types = dataset.output_types() 1812 >>> print(output_types) 1813 [dtype('float32'), dtype('int32')] 1814 """ 1815 if self.saved_output_types is None: 1816 runtime_getter = self._init_tree_getters() 1817 # We have a hang problem when two-level pipeline with multiprocessing, we need to extend the life cycle 1818 # of runtime_context. We found this hang problem only occur on output_types and output_shapes. 1819 self.runtime_context = runtime_getter[1] 1820 api_tree = runtime_getter[2] 1821 self.saved_output_types = runtime_getter[0].GetOutputTypes() 1822 del api_tree 1823 # Need to terminate the runtime context to avoid the occasional hang problem for 1824 # Python (with multiprocessing enabled) in sink mode. 1825 self.runtime_context.Terminate() 1826 del self.runtime_context 1827 return self.saved_output_types 1828 1829 def get_dataset_size(self): 1830 """ 1831 Return the number of batches in an epoch. 1832 1833 Returns: 1834 int, number of batches. 1835 1836 Examples: 1837 >>> import mindspore.dataset as ds 1838 >>> import numpy as np 1839 >>> 1840 >>> # A generator return 66 samples 1841 >>> def generator1(): 1842 ... for i in range(66): 1843 ... yield np.ones((16, 83, 83)), np.array([i]) 1844 >>> 1845 >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"]) 1846 >>> dataset_size = dataset.get_dataset_size() 1847 >>> print(dataset_size) 1848 66 1849 """ 1850 if self.dataset_size is None: 1851 runtime_getter = self.__init_size_getter() 1852 self.dataset_size = runtime_getter[0].GetDatasetSize(False) 1853 if self.dataset_size == 0: 1854 logger.warning("Got 0 sample from dataset pipeline, check if drop all data or load dataset fail.") 1855 1856 return self.dataset_size 1857 1858 def num_classes(self): 1859 """ 1860 Get the number of classes in a dataset. 1861 1862 Returns: 1863 int, number of classes. 1864 1865 Examples: 1866 >>> import mindspore.dataset as ds 1867 >>> # Read image files 1868 >>> image_folder_dataset_dir = "/path/to/image_folder_dataset_directory" 1869 >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir) 1870 >>> # Check how many classes exist in image folder 1871 >>> num_classes = dataset.num_classes() 1872 """ 1873 if self._num_classes is None: 1874 runtime_getter = self._init_tree_getters() 1875 self._num_classes = runtime_getter[0].GetNumClasses() 1876 1877 if self._num_classes == -1: 1878 return None 1879 return self._num_classes 1880 1881 def get_sync_notifiers(self): 1882 if self.children: 1883 return self.children[0].get_sync_notifiers() 1884 return {} 1885 1886 def disable_sync(self): 1887 if self.children: 1888 return self.children[0].disable_sync() 1889 return {} 1890 1891 def is_sync(self): 1892 if self.children: 1893 return self.children[0].is_sync() 1894 return False 1895 1896 def sync_update(self, condition_name, num_batch=None, data=None): 1897 """ 1898 Release a blocking condition and trigger callback with given data. 1899 1900 Args: 1901 condition_name (str): The condition name that is used to toggle sending next row. 1902 num_batch (Union[int, None]): The number of batches (rows) that are released. 1903 When `num_batch` is ``None``, it will default to the number specified by the 1904 `sync_wait` operation. Default: ``None``. 1905 data (Any): The data passed to the callback, user defined. Default: ``None``. 1906 1907 Examples: 1908 >>> import numpy as np 1909 >>> import mindspore.dataset as ds 1910 >>> 1911 >>> def gen(): 1912 ... for i in range(100): 1913 ... yield (np.array(i),) 1914 >>> 1915 >>> class Augment: 1916 ... def __init__(self, loss): 1917 ... self.loss = loss 1918 ... 1919 ... def preprocess(self, input_): 1920 ... return input_ 1921 ... 1922 ... def update(self, data): 1923 ... self.loss = data["loss"] 1924 >>> 1925 >>> batch_size = 10 1926 >>> dataset = ds.GeneratorDataset(gen, column_names=["input"]) 1927 >>> aug = Augment(0) 1928 >>> dataset = dataset.sync_wait(condition_name='', num_batch=1) 1929 >>> dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) 1930 >>> dataset = dataset.batch(batch_size) 1931 >>> 1932 >>> count = 0 1933 >>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): 1934 ... count += 1 1935 ... data = {"loss": count} 1936 ... dataset.sync_update(condition_name="", data=data) 1937 """ 1938 if (not isinstance(num_batch, int) and num_batch is not None) or \ 1939 (isinstance(num_batch, int) and num_batch <= 0): 1940 # throwing exception, disable all sync_wait in pipeline 1941 self.disable_sync() 1942 raise RuntimeError("Sync_update batch size can only be positive integer, got : {}.".format(num_batch)) 1943 notifiers_dict = self.get_sync_notifiers() 1944 if not isinstance(condition_name, str): 1945 raise TypeError("Argument condition_name with value {} is not of type str, but got {}." 1946 .format(condition_name, type(condition_name))) 1947 if condition_name not in notifiers_dict: 1948 # throwing exception, disable all sync_wait in pipeline 1949 self.disable_sync() 1950 raise RuntimeError("Condition name not found.") 1951 if num_batch is not None: 1952 num_batch *= self.get_batch_size() 1953 notifiers_dict[condition_name](num_batch, data) 1954 1955 def get_batch_size(self): 1956 """ 1957 Return the size of batch. 1958 1959 Returns: 1960 int, the batch size of data. 1961 1962 Examples: 1963 >>> import mindspore.dataset as ds 1964 >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1") 1965 >>> dataset = dataset.batch(2) 1966 >>> batch_size = dataset.get_batch_size() 1967 >>> print(batch_size) 1968 2 1969 """ 1970 if self._batch_size is None: 1971 runtime_getter = self._init_tree_getters() 1972 self._batch_size = runtime_getter[0].GetBatchSize() 1973 if self._batch_size is None: 1974 self._batch_size = 1 1975 return self._batch_size 1976 1977 def get_repeat_count(self): 1978 """ 1979 Get the replication times in RepeatDataset. Default: ``1`` . 1980 1981 Returns: 1982 int, the count of repeat. 1983 1984 Examples: 1985 >>> import mindspore.dataset as ds 1986 >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1") 1987 >>> dataset = dataset.repeat(5) 1988 >>> repeat_count = dataset.get_repeat_count() 1989 >>> print(repeat_count) 1990 5 1991 """ 1992 if self._repeat_count is None: 1993 runtime_getter = self._init_tree_getters() 1994 self._repeat_count = runtime_getter[0].GetRepeatCount() 1995 if self._repeat_count is None: 1996 self._repeat_count = 1 1997 return self._repeat_count 1998 1999 def get_class_indexing(self): 2000 """ 2001 Get the mapping dictionary from category names to category indexes. 2002 2003 This dictionary can be used to look up which category name corresponds to a particular category index. 2004 2005 Returns: 2006 Dict[str, int], the mappings from category names to category indexes. 2007 2008 Examples: 2009 >>> import mindspore.dataset as ds 2010 >>> # Read image files 2011 >>> image_folder_dataset_dir = "/path/to/image_folder_dataset_directory" 2012 >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir) 2013 >>> # Check how many classes exist in image folder 2014 >>> class_indexing = dataset.get_class_indexing() 2015 """ 2016 if self.children: 2017 return self.children[0].get_class_indexing() 2018 return {} 2019 2020 def reset(self): 2021 """ 2022 Reset the dataset for next epoch. 2023 2024 Examples: 2025 >>> import mindspore.dataset as ds 2026 >>> mind_dataset_dir = ["/path/to/mind_dataset_file"] 2027 >>> dataset = ds.MindDataset(dataset_files=mind_dataset_dir) 2028 >>> for _ in range(5): 2029 ... num_iter = 0 2030 ... for data in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True): 2031 ... num_iter += 1 2032 ... dataset.reset() 2033 """ 2034 2035 def is_shuffled(self): 2036 """Returns True if the dataset or its children is shuffled.""" 2037 for input_dataset in self.children: 2038 if input_dataset.is_shuffled(): 2039 return True 2040 2041 return False 2042 2043 def is_sharded(self): 2044 """Returns True if the dataset or its children is sharded.""" 2045 for input_dataset in self.children: 2046 if input_dataset.is_sharded(): 2047 return True 2048 2049 return False 2050 2051 def parse(self, children=None): 2052 raise NotImplementedError("Dataset has to implement parse method.") 2053 2054 def __len__(self): 2055 """ 2056 Get the length of dataset. 2057 2058 Returns: 2059 int, the length of dataset. 2060 """ 2061 return self.get_dataset_size() 2062 2063 @staticmethod 2064 def _update_data_shard(num_shards, shard_id): 2065 """ 2066 Update the shard number and shard id if necessary. 2067 This is normally used in distributed training mode like Parameter Server training. 2068 """ 2069 # If this is in distributed execution mode, 2070 # the shard number and shard id might need to be updated according to the process's rank or role. 2071 worker_num = _get_ps_context("worker_num") 2072 server_num = _get_ps_context("server_num") 2073 if _is_role_pserver() and _enable_distributed_mindrt() and (worker_num != server_num): 2074 num_shards = worker_num 2075 shard_id = 0 2076 return num_shards, shard_id 2077 2078 def pre_parse(self, getter_mode): 2079 if getter_mode: 2080 if hasattr(self, "python_multiprocessing"): 2081 self.python_multiprocessing = False 2082 if hasattr(self, "num_parallel_workers"): 2083 self.num_parallel_workers = 1 2084 2085 def post_parse(self, ir_node): 2086 if self.cache: 2087 ir_node = ir_node.set_cache_client(self.cache.cache_client) 2088 if self.num_parallel_workers: 2089 ir_node = ir_node.set_num_workers(self.num_parallel_workers) 2090 2091 return ir_node 2092 2093 def set_init_step(self, init_step): 2094 self._global_step = init_step 2095 2096 def get_init_step(self): 2097 if self._global_step is not None: 2098 return self._global_step 2099 if len(self.children) == 1: 2100 return self.children[0].get_init_step() 2101 # When there are multiple children, we cannot tell from which child to get the initial step, 2102 # so we initialize from the beginning 2103 return 0 2104 2105 2106class VisionBaseDataset(Dataset): 2107 """ 2108 Abstract class to represent a vision source dataset which produces content to the data pipeline. 2109 """ 2110 2111 def __init__(self, children=None, num_parallel_workers=None, cache=None): 2112 super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache) 2113 2114 def parse(self, children=None): 2115 raise NotImplementedError("Dataset has to implement parse method.") 2116 2117 2118class TextBaseDataset(Dataset): 2119 """ 2120 Abstract class to represent a text source dataset which produces content to the data pipeline. 2121 """ 2122 2123 def __init__(self, children=None, num_parallel_workers=None, cache=None): 2124 super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache) 2125 2126 def parse(self, children=None): 2127 raise NotImplementedError("Dataset has to implement parse method.") 2128 2129 def build_vocab(self, columns, freq_range, top_k, special_tokens, special_first): 2130 """ 2131 Function to create a Vocab from source dataset. 2132 Desired source dataset is a text type dataset. 2133 2134 Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab 2135 which contains top_k most frequent words (if top_k is specified). 2136 2137 Note: 2138 mindspore.dataset.Dataset.build_vocab is deprecated from version 2.0 2139 and will be removed in a future version. Use mindspore.dataset.text.Vocab.from_dataset instead. 2140 2141 Args: 2142 columns(Union[str, list[str]]): Column names to get words from. 2143 freq_range(tuple[int]): A tuple of integers (min_frequency, max_frequency). Words within the frequency 2144 range will be stored. 2145 Naturally 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency 2146 can be set to default, which corresponds to 0/total_words separately. 2147 top_k(int): Number of words to be built into vocab. top_k most frequent words are 2148 taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken 2149 special_tokens(list[str]): A list of strings, each one is a special token. 2150 special_first(bool): Whether special_tokens will be prepended/appended to vocab, If special_tokens 2151 is specified and special_first is set to default, special_tokens will be prepended. 2152 2153 Returns: 2154 Vocab, vocab built from the dataset. 2155 """ 2156 warnings.warn("mindspore.dataset.Dataset.build_vocab is deprecated from version 2.0 " 2157 "and will be removed in a future version. " 2158 "Use mindspore.dataset.text.Vocab.from_dataset instead.", DeprecationWarning) 2159 2160 def build_sentencepiece_vocab(self, columns, vocab_size, character_coverage, model_type, params): 2161 """ 2162 Function to create a SentencePieceVocab from source dataset. 2163 Desired source dataset is a text type dataset. 2164 2165 Note: 2166 mindspore.dataset.Dataset.build_sentencepiece_vocab is deprecated from version 2.0 2167 and will be removed in a future version. Use mindspore.dataset.text.SentencePieceVocab.from_dataset instead. 2168 2169 Args: 2170 columns(list[str]): Column names to get words from. 2171 vocab_size(int): Vocabulary size. 2172 character_coverage(float): Percentage of characters covered by the model, must be between 2173 0.98 and 1.0 Good defaults are: 0.9995 for languages with rich character sets like 2174 Japanese or Chinese character sets, and 1.0 for other languages with small character sets 2175 like English or Latin. 2176 model_type(SentencePieceModel): Model type. Choose from unigram (default), bpe, char, or word. 2177 The input sentence must be pretokenized when using word type. 2178 params(dict): Any extra optional parameters of sentencepiece library according to your raw data 2179 2180 Returns: 2181 SentencePieceVocab, vocab built from the dataset. 2182 """ 2183 warnings.warn("mindspore.dataset.Dataset.build_sentencepiece_vocab is deprecated from version 2.0 " 2184 "and will be removed in a future version. " 2185 "Use mindspore.dataset.text.SentencePieceVocab.from_dataset instead.", DeprecationWarning) 2186 2187 def _build_vocab(self, columns, freq_range, top_k, special_tokens, special_first): 2188 """ 2189 Function to create a Vocab from source dataset. 2190 Desired source dataset is a text type dataset. 2191 2192 Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab 2193 which contains top_k most frequent words (if top_k is specified). 2194 2195 Args: 2196 columns(Union[str, list[str]]): Column names to get words from. 2197 freq_range(tuple[int]): A tuple of integers (min_frequency, max_frequency). Words within the frequency 2198 range will be stored. 2199 Naturally 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency 2200 can be set to default, which corresponds to 0/total_words separately. 2201 top_k(int): Number of words to be built into vocab. top_k most frequent words are 2202 taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken 2203 special_tokens(list[str]): A list of strings, each one is a special token. 2204 special_first(bool): Whether special_tokens will be prepended/appended to vocab, If special_tokens 2205 is specified and special_first is set to default, special_tokens will be prepended. 2206 2207 Returns: 2208 Vocab, vocab built from the dataset. 2209 """ 2210 vocab = cde.Vocab() 2211 columns = replace_none(columns, []) 2212 if not isinstance(columns, list): 2213 columns = [columns] 2214 2215 freq_range = replace_none(freq_range, (0, 9223372036854775807)) 2216 if freq_range[0] is None: 2217 freq_range = (0, freq_range[1]) 2218 if freq_range[1] is None: 2219 freq_range = (freq_range[0], 9223372036854775807) 2220 special_tokens = replace_none(special_tokens, []) 2221 top_k = replace_none(top_k, 9223372036854775807) 2222 2223 ir_tree, api_tree = self.create_ir_tree() 2224 2225 # vocab node 2226 vocab_node = cde.BuildVocabNode(ir_tree, vocab, columns, freq_range, top_k, special_tokens, special_first) 2227 2228 runtime_context = cde.PythonRuntimeContext() 2229 runtime_context.Init() 2230 2231 # build vocab 2232 consumer = cde.PythonBuildVocabConsumer() 2233 consumer.Init(vocab_node) 2234 runtime_context.AssignConsumer(consumer) 2235 2236 consumer.Start() 2237 del api_tree 2238 2239 return vocab 2240 2241 def _build_sentencepiece_vocab(self, columns, vocab_size, character_coverage, model_type, params): 2242 """ 2243 Function to create a SentencePieceVocab from source dataset. 2244 Desired source dataset is a text type dataset. 2245 2246 Args: 2247 columns(list[str]): Column names to get words from. 2248 vocab_size(int): Vocabulary size. 2249 character_coverage(float): Percentage of characters covered by the model, must be between 2250 0.98 and 1.0 Good defaults are: 0.9995 for languages with rich character sets like 2251 Japanese or Chinese character sets, and 1.0 for other languages with small character sets 2252 like English or Latin. 2253 model_type(SentencePieceModel): Model type. Choose from unigram (default), bpe, char, or word. 2254 The input sentence must be pretokenized when using word type. 2255 params(dict): Any extra optional parameters of sentencepiece library according to your raw data 2256 2257 Returns: 2258 SentencePieceVocab, vocab built from the dataset. 2259 """ 2260 if not isinstance(model_type, SentencePieceModel): 2261 raise TypeError("Argument model_type with value {0} is not of type SentencePieceModel, but got {1}." \ 2262 .format(model_type, type(model_type))) 2263 model_type = DE_C_INTER_SENTENCEPIECE_MODE[model_type] 2264 vocab = cde.SentencePieceVocab() 2265 2266 ir_tree, api_tree = self.create_ir_tree() 2267 2268 # vocab node 2269 vocab_node = cde.BuildSentenceVocabNode(ir_tree, vocab, columns, vocab_size, character_coverage, model_type, 2270 params) 2271 2272 runtime_context = cde.PythonRuntimeContext() 2273 runtime_context.Init() 2274 2275 # build vocab 2276 consumer = cde.PythonBuildVocabConsumer() 2277 consumer.Init(vocab_node) 2278 runtime_context.AssignConsumer(consumer) 2279 2280 consumer.Start() 2281 del api_tree 2282 2283 return vocab 2284 2285 2286class AudioBaseDataset(Dataset): 2287 """ 2288 Abstract class to represent a audio source dataset which produces content to the data pipeline. 2289 """ 2290 2291 def __init__(self, children=None, num_parallel_workers=None, cache=None): 2292 super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache) 2293 2294 def parse(self, children=None): 2295 raise NotImplementedError("Dataset has to implement parse method.") 2296 2297 2298class UnionBaseDataset(VisionBaseDataset, TextBaseDataset, AudioBaseDataset): 2299 """ 2300 Abstract class to represent a union source dataset which produces content to the data pipeline. 2301 """ 2302 2303 def __init__(self, children=None, num_parallel_workers=None, cache=None): 2304 super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache) 2305 2306 def parse(self, children=None): 2307 raise NotImplementedError("Dataset has to implement parse method.") 2308 2309 2310class SourceDataset(Dataset): 2311 """ 2312 Abstract class to represent a source dataset which produces content to the data pipeline. 2313 """ 2314 2315 def __init__(self, num_parallel_workers=None, num_samples=None, shuffle=True, num_shards=None, shard_id=None, 2316 cache=None): 2317 super().__init__(num_parallel_workers=num_parallel_workers, cache=cache) 2318 self.num_samples = replace_none(num_samples, 0) 2319 self.num_shards = replace_none(num_shards, 1) 2320 self.shard_id = replace_none(shard_id, 0) 2321 2322 if shuffle is not None and not isinstance(shuffle, (bool, Shuffle)): 2323 raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like 'Shuffle.GLOBAL' or " 2324 "'Shuffle.FILES' or 'Shuffle.INFILE'.") 2325 2326 self.shuffle_flag = 2 # Global shuffle 2327 if not isinstance(shuffle, Shuffle): 2328 if shuffle is None or shuffle: 2329 self.shuffle_flag = 2 # Global shuffle 2330 else: 2331 self.shuffle_flag = 0 # No shuffle 2332 else: 2333 if shuffle == Shuffle.GLOBAL: 2334 self.shuffle_flag = 2 # Global shuffle 2335 elif shuffle == Shuffle.FILES: 2336 self.shuffle_flag = 1 # Files shuffle 2337 elif shuffle == Shuffle.INFILE: 2338 self.shuffle_flag = 3 # Infile shuffle 2339 2340 def parse(self, children=None): 2341 raise NotImplementedError("Dataset has to implement parse method.") 2342 2343 @staticmethod 2344 def _find_files(patterns): 2345 """ 2346 Utility function to search for files with the given glob patterns. 2347 2348 Args: 2349 patterns (Union[str, list[str]]): String or list of patterns to be searched. 2350 2351 Returns: 2352 list, list of files. 2353 """ 2354 2355 if not isinstance(patterns, list): 2356 patterns = [patterns] 2357 2358 file_list = [] 2359 unmatched_patterns = [] 2360 for pattern in patterns: 2361 matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)] 2362 2363 if matches: 2364 file_list.extend(matches) 2365 else: 2366 unmatched_patterns.append(pattern) 2367 2368 if unmatched_patterns: 2369 raise ValueError("The following patterns did not match any files: {}.".format(unmatched_patterns)) 2370 2371 if file_list: # not empty 2372 return file_list 2373 raise ValueError("The list of path names matching the patterns is empty.") 2374 2375 def is_shuffled(self): 2376 return self.shuffle_flag > 0 2377 2378 def is_sharded(self): 2379 if self.num_shards is not None: 2380 return self.num_shards > 1 2381 return False 2382 2383 2384class MappableDataset(SourceDataset): 2385 """ 2386 Abstract class to represent a source dataset which supports use of samplers. 2387 """ 2388 2389 def parse(self, children=None): 2390 raise NotImplementedError("Dataset has to implement parse method.") 2391 2392 def __init__(self, num_parallel_workers=None, sampler=None, num_samples=None, shuffle=None, num_shards=None, 2393 shard_id=None, cache=None): 2394 num_shards, shard_id = self._update_data_shard(num_shards, shard_id) 2395 super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, 2396 num_shards=num_shards, shard_id=shard_id, cache=cache) 2397 self.shuffle_flag = replace_none(shuffle, True) 2398 self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) 2399 2400 def add_sampler(self, new_sampler): 2401 """ 2402 Add a child sampler for the current dataset. 2403 2404 Args: 2405 new_sampler (Sampler): The child sampler to be added. 2406 2407 Examples: 2408 >>> import mindspore.dataset as ds 2409 >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1") 2410 >>> 2411 >>> new_sampler = ds.DistributedSampler(10, 2) 2412 >>> dataset.add_sampler(new_sampler) 2413 """ 2414 # Note: By adding a sampler, the sampled IDs will flow to the new_sampler 2415 # after first passing through the current samplers attached to this dataset. 2416 self.dataset_size = None 2417 new_sampler.add_child(self.sampler) 2418 self.sampler = new_sampler 2419 2420 def use_sampler(self, new_sampler): 2421 """ 2422 Replace the last child sampler of the current dataset, remaining the parent sampler unchanged. 2423 2424 Args: 2425 new_sampler (Sampler): The new sampler to replace with. 2426 2427 Examples: 2428 >>> import mindspore.dataset as ds 2429 >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1") 2430 >>> 2431 >>> # use a DistributedSampler instead 2432 >>> new_sampler = ds.DistributedSampler(10, 2) 2433 >>> dataset.use_sampler(new_sampler) 2434 """ 2435 if new_sampler is None: 2436 raise TypeError("Input sampler can not be None.") 2437 if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): 2438 raise TypeError("Input sampler is not an instance of a sampler.") 2439 self.dataset_size = None 2440 2441 self.sampler = self.sampler.child_sampler 2442 self.add_sampler(new_sampler) 2443 2444 def is_shuffled(self): 2445 return self.sampler.is_shuffled() 2446 2447 def is_sharded(self): 2448 return self.sampler.is_sharded() 2449 2450 @check_split 2451 def split(self, sizes, randomize=True): 2452 """ 2453 Split the dataset into smaller, non-overlapping datasets. 2454 2455 Args: 2456 sizes (Union[list[int], list[float]]): If a list of integers [s1, s2, …, sn] is 2457 provided, the dataset will be split into n datasets of size s1, size s2, …, size sn 2458 respectively. If the sum of all sizes does not equal the original dataset size, an 2459 error will occur. 2460 If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1 2461 and must sum to 1, otherwise an error will occur. The dataset will be split into n 2462 Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the 2463 original dataset. 2464 If after rounding: 2465 2466 - Any size equals 0, an error will occur. 2467 - The sum of split sizes < K, the difference will be added to the first split. 2468 - The sum of split sizes > K, the difference will be removed from the first large 2469 enough split such that it will have at least 1 row after removing the difference. 2470 2471 randomize (bool, optional): Determines whether or not to split the data randomly. Default: ``True``. 2472 If ``True``, the data will be randomly split. Otherwise, each split will be created with 2473 consecutive rows from the dataset. 2474 2475 Note: 2476 1. There is an optimized split function, which will be called automatically when the dataset 2477 that calls this function is a MappableDataset. 2478 2. Dataset should not be sharded if split is going to be called. Instead, create a 2479 :class:`mindspore.dataset.DistributedSampler` and specify a split to shard after splitting. 2480 If the dataset is sharded after a split, it is strongly recommended setting the same 2481 seed in each instance of execution, otherwise each shard may not be part of the same 2482 split (see Examples). 2483 3. It is strongly recommended to not shuffle the dataset, but set `randomize` to ``True`` instead. 2484 Shuffling the dataset may not be deterministic, which means the data in each split 2485 will be different in each epoch. Furthermore, if sharding occurs after split, each 2486 shard may not be part of the same split. 2487 2488 Returns: 2489 Tuple[Dataset], a tuple of new datasets split from the original one. 2490 2491 Raises: 2492 RuntimeError: If get_dataset_size returns None or is not supported for this dataset. 2493 RuntimeError: If `sizes` is list of integers and sum of all elements in sizes does not 2494 equal the dataset size. 2495 RuntimeError: If `sizes` is list of float and there is a split with size 0 after calculations. 2496 RuntimeError: If the dataset is sharded prior to calling split. 2497 ValueError: If `sizes` is list of float and not all floats are between 0 and 1, or if the 2498 floats don't sum to 1. 2499 2500 Examples: 2501 >>> import mindspore.dataset as ds 2502 >>> # Since many datasets have shuffle on by default, set shuffle to False if split will be called! 2503 >>> image_folder_dataset_dir = "/path/to/image_folder_dataset_directory" 2504 >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, shuffle=False) 2505 >>> 2506 >>> # Set the seed, and tell split to use this seed when randomizing. 2507 >>> # This is needed because sharding will be done later 2508 >>> ds.config.set_seed(58) 2509 >>> train_dataset, test_dataset = dataset.split([0.9, 0.1]) 2510 >>> 2511 >>> # To shard the train dataset, use a DistributedSampler 2512 >>> train_sampler = ds.DistributedSampler(10, 2) 2513 >>> train_dataset.use_sampler(train_sampler) 2514 """ 2515 if self.is_shuffled(): 2516 logger.warning("Dataset is shuffled before split.") 2517 2518 if self.is_sharded(): 2519 raise RuntimeError("Dataset should not be sharded before split.") 2520 2521 absolute_sizes = self._get_absolute_split_sizes(sizes) 2522 splits = [] 2523 current_split_start_index = 0 2524 for size in absolute_sizes: 2525 ds = copy.deepcopy(self) 2526 ds.dataset_size = None 2527 if randomize: 2528 # want to shuffle the same way every epoch before split, we are assuming 2529 # that the user will call set_seed 2530 random_sampler = samplers.RandomSampler() 2531 random_sampler.reshuffle_each_epoch = False 2532 ds.add_sampler(random_sampler) 2533 2534 subset_sampler = samplers.SequentialSampler(current_split_start_index, size) 2535 ds.add_sampler(subset_sampler) 2536 2537 # add sequential sampler, so that if user calls use_sampler, we will 2538 # get rid of the sequential sampler instead of something we need 2539 ds.add_sampler(samplers.SequentialSampler()) 2540 2541 splits.append(ds) 2542 2543 current_split_start_index += size 2544 2545 return tuple(splits) 2546 2547 2548class BucketBatchByLengthDataset(UnionBaseDataset): 2549 """ 2550 The result of applying BucketBatchByLength operation to the input dataset. 2551 """ 2552 2553 def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, 2554 pad_info, pad_to_bucket_boundary, drop_remainder): 2555 super().__init__(children=input_dataset) 2556 2557 self.column_names = to_list(column_names) 2558 self.bucket_boundaries = replace_none(bucket_boundaries, []) 2559 self.bucket_batch_sizes = replace_none(bucket_batch_sizes, []) 2560 self.element_length_function = element_length_function 2561 self.pad_info = replace_none(pad_info, {}) 2562 self.pad_to_bucket_boundary = replace_none(pad_to_bucket_boundary, False) 2563 self.drop_remainder = replace_none(drop_remainder, False) 2564 2565 def parse(self, children=None): 2566 return cde.BucketBatchByLengthNode(children[0], self.column_names, self.bucket_boundaries, 2567 self.bucket_batch_sizes, self.element_length_function, self.pad_info, 2568 self.pad_to_bucket_boundary, self.drop_remainder) 2569 2570 2571def _check_shm_usage(num_worker, queue_size, in_rowsize, out_rowsize): 2572 """ 2573 Check sufficient shared memory is available for shared memory queues 2574 when training in parallel mode. 2575 """ 2576 threshold_ratio = 0.8 2577 # Verify available size only when using static shared memory on Linux 2578 if platform.system().lower() not in {"windows", "darwin"} and in_rowsize != -1 and out_rowsize != -1: 2579 device_num = _get_device_num() 2580 # In the cluster, _get_device_num indicates the number of the entire cluster. The maximum number of cards 2581 # on the ascend server is 8. 2582 if device_num > 1: 2583 device_num = min(device_num, 8) 2584 shm_estimate_usage = device_num * num_worker * \ 2585 (queue_size + 2) * (in_rowsize + out_rowsize) * 1024 * 1024 2586 try: 2587 shm_available = psutil.disk_usage('/dev/shm').free 2588 if shm_estimate_usage >= threshold_ratio * shm_available: 2589 raise RuntimeError( 2590 "Insufficient shared memory available. Required: {}, Available: {}. " 2591 "The required memory can't exceed 80% of the available shared memory, " 2592 "it's recommended to reduce memory usage by following methods:\n" 2593 "1. reduce value of parameter max_rowsize or num_parallel_workers.\n" 2594 "2. reduce prefetch size by set_prefetch_size().\n" 2595 "3. disable shared memory by set_enable_shared_mem().".format(shm_estimate_usage, shm_available)) 2596 except FileNotFoundError: 2597 raise RuntimeError("Expected /dev/shm to exist.") 2598 2599 2600class BatchDataset(UnionBaseDataset): 2601 """ 2602 The result of applying Batch operation to the input dataset. 2603 2604 Args: 2605 input_dataset (Dataset): Input Dataset to be batched. 2606 batch_size (Union[int, function]): The number of rows each batch is created with. An 2607 int or callable which takes exactly 1 parameter, BatchInfo. 2608 drop_remainder (bool, optional): Determines whether or not to drop the last 2609 possibly incomplete batch. Default: ``False``. If True, and if there are less 2610 than batch_size rows available to make the last batch, then those rows will 2611 be dropped and not propagated to the child node. 2612 num_parallel_workers (int, optional): Number of workers to process the dataset in parallel. Default: ``None``. 2613 per_batch_map (callable, optional): Per batch map callable. A callable which takes 2614 (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch of 2615 Tensors on a given column. The number of lists should match with number of entries in input_columns. The 2616 last parameter of the callable must always be a BatchInfo object. 2617 input_columns (Union[str, list[str]], optional): List of names of the input columns. The size of the list must 2618 match with signature of per_batch_map callable. 2619 output_columns (Union[str, list[str]], optional): List of names assigned to the columns outputted by 2620 the last operation. This parameter is mandatory if len(input_columns) != 2621 len(output_columns). The size of this list must match the number of output 2622 columns of the last operation. Default: ``None``, output columns will have the same 2623 name as the input columns, i.e., the columns will be replaced. 2624 max_rowsize(Union[int, list[int]], optional): Maximum size of row in MB that is used for shared memory 2625 allocation to copy data between processes, the total occupied shared memory will increase as 2626 ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set to -1, 2627 shared memory will be dynamically allocated with the actual size of data. This is only used if 2628 ``python_multiprocessing`` is set to True. If it is an int value, it represents 2629 ``input_columns`` and ``output_columns`` use this value as the unit to create shared memory. 2630 If it is a list, the first element represents the ``input_columns`` use this value as the unit to 2631 create shared memory, and the second element represents ``output_columns`` use this value as the unit 2632 to create shared memory. Default: 16. 2633 2634 """ 2635 2636 def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, 2637 input_columns=None, output_columns=None, python_multiprocessing=False, max_rowsize=16): 2638 super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers) 2639 2640 if BatchDataset._is_ancestor_of_repeat(input_dataset): 2641 logger.warning("Repeat is located before batch, data from two epochs can be batched together.") 2642 2643 BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) 2644 2645 # if batch_size is callable, set batch_size to 1 and batch_size_func to that callable function 2646 self.batch_size = batch_size if not callable(batch_size) else 1 2647 self.batch_size_func = None if not callable(batch_size) else batch_size 2648 2649 self.drop_remainder = replace_none(drop_remainder, False) 2650 2651 self.per_batch_map = per_batch_map 2652 2653 self.input_columns = to_list(input_columns) 2654 self.output_columns = to_list(output_columns) 2655 2656 self.python_multiprocessing = python_multiprocessing 2657 self.process_pool = None 2658 if isinstance(max_rowsize, int): 2659 self.max_rowsize = [max_rowsize * self.batch_size] * 2 if max_rowsize != -1 else [max_rowsize, max_rowsize] 2660 else: 2661 self.max_rowsize = [max_rowsize[0] * self.batch_size, max_rowsize[1] * self.batch_size] 2662 2663 def __del__(self): 2664 if hasattr(self, "process_pool") and self.process_pool is not None: 2665 self.process_pool.terminate() 2666 del self.process_pool 2667 2668 def parse(self, children=None): 2669 return cde.BatchNode(children[0], self.batch_size, self.drop_remainder, False, self.input_columns, 2670 self.output_columns, self.batch_size_func, self.per_batch_map, {}, 2671 self.process_pool) 2672 2673 @staticmethod 2674 def _is_ancestor_of_repeat(dataset): 2675 """ 2676 Utility function to find the case where repeat is used before batch. 2677 2678 Args: 2679 dataset (Dataset): Dataset to be checked. 2680 2681 Returns: 2682 bool, whether repeat is used before batch. 2683 """ 2684 if isinstance(dataset, RepeatDataset): 2685 return True 2686 flag = False 2687 for input_dataset in dataset.children: 2688 flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset) 2689 return flag 2690 2691 @staticmethod 2692 def _update_batch_size_for_syncwait(dataset, batch_size): 2693 """ 2694 Utility function to notify batch size to sync_wait. 2695 2696 Args: 2697 dataset (Dataset): Dataset to be checked. 2698 batch_size (int): batch size to notify. 2699 """ 2700 if isinstance(dataset, SyncWaitDataset): 2701 dataset.update_sync_batch_size(batch_size) 2702 for input_dataset in dataset.children: 2703 BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) 2704 2705 def __deepcopy__(self, memodict): 2706 return self.__safe_deepcopy__(memodict, exclude=("per_batch_map", "batch_size_func", "__transfer_dataset__")) 2707 2708 # Iterator bootstrap will be called on iterator construction. 2709 # A deep copy of Dataset object is created prior of iterator_bootstrap. 2710 # This method will create per iterator process pool and bind pyfunc execution to the pool. 2711 def iterator_bootstrap(self): 2712 """ 2713 Per iterator bootstrap callback. 2714 """ 2715 if self.python_multiprocessing and platform.system().lower() == 'windows': 2716 logger.warning("Python multiprocessing is not supported on Windows platform.") 2717 if self.python_multiprocessing and get_debug_mode(): 2718 logger.warning("Python multiprocessing is not supported in debug mode." 2719 " Ignoring Python multiprocessing for batch operation.") 2720 self.python_multiprocessing = False 2721 if self.python_multiprocessing and platform.system().lower() != 'windows': 2722 if self.per_batch_map is None: 2723 logger.warning("per_batch_map is None so python_multiprocessing is ignored for batch.") 2724 return 2725 2726 # If user didn't specify num_parallel_workers, set it to default 2727 if self.num_parallel_workers is None: 2728 self.num_parallel_workers = get_num_parallel_workers() 2729 2730 self.process_pool = _PythonMultiprocessing(str(self), self.num_parallel_workers, [self.per_batch_map], 2731 self.max_rowsize) 2732 # Wrap per_batch_map into _PythonCallable 2733 self.per_batch_map = _PythonCallable(self.per_batch_map, 0, self.process_pool) 2734 else: 2735 if self.per_batch_map is not None: 2736 self.per_batch_map = FuncWrapper(self.per_batch_map) 2737 2738 2739class BatchInfo(cde.CBatchInfo): 2740 """ 2741 This class helps to get dataset information dynamically when the input of `batch_size` or `per_batch_map` 2742 in `batch` operation is a callable object. 2743 """ 2744 2745 def get_batch_num(self): 2746 """ 2747 Return the batch number being processed in current epoch, start from 0. 2748 2749 Examples: 2750 >>> # Create a dataset where its batch size is dynamic 2751 >>> # Define a callable batch size function and let batch size increase 1 each time. 2752 >>> import mindspore.dataset as ds 2753 >>> from mindspore.dataset import BatchInfo 2754 >>> 2755 >>> dataset = ds.GeneratorDataset([i for i in range(3)], "column1", shuffle=False) 2756 >>> def add_one(BatchInfo): 2757 ... return BatchInfo.get_batch_num() + 1 2758 >>> dataset = dataset.batch(batch_size=add_one) 2759 >>> print(list(dataset)) 2760 [[Tensor(shape=[1], dtype=Int64, value= [0])], [Tensor(shape=[2], dtype=Int64, value= [1, 2])]] 2761 """ 2762 return 2763 2764 def get_epoch_num(self): 2765 """ 2766 Return the epoch number, start from 0. 2767 2768 Examples: 2769 >>> # Create a dataset where its batch size is dynamic 2770 >>> # Define a callable batch size function and let batch size increase 1 each epoch. 2771 >>> import mindspore.dataset as ds 2772 >>> from mindspore.dataset import BatchInfo 2773 >>> 2774 >>> dataset = ds.GeneratorDataset([i for i in range(4)], "column1", shuffle=False) 2775 >>> def add_one_by_epoch(BatchInfo): 2776 ... return BatchInfo.get_epoch_num() + 1 2777 >>> dataset = dataset.batch(batch_size=add_one_by_epoch) 2778 >>> 2779 >>> result = [] 2780 >>> epoch = 2 2781 >>> iterator = dataset.create_tuple_iterator(num_epochs=epoch) 2782 >>> for i in range(epoch): 2783 ... result.extend(list(iterator)) 2784 >>> # result: 2785 >>> # [[Tensor(shape=[1], dtype=Int64, value= [0])], [Tensor(shape=[1], dtype=Int64, value= [1])], 2786 >>> # [Tensor(shape=[1], dtype=Int64, value= [2])], [Tensor(shape=[1], dtype=Int64, value= [3])], 2787 >>> # [Tensor(shape=[2], dtype=Int64, value= [0, 1])], [Tensor(shape=[2], dtype=Int64, value= [2, 3])]] 2788 """ 2789 return 2790 2791 2792class BlockReleasePair: 2793 """ 2794 The blocking condition class used by SyncWaitDataset. 2795 2796 Args: 2797 init_release_rows (int): Number of lines to allow through the pipeline. 2798 callback (function): The callback function that will be called when release is called. Default: ``None``. 2799 """ 2800 2801 def __init__(self, init_release_rows, callback=None): 2802 if isinstance(init_release_rows, int) and init_release_rows <= 0: 2803 raise ValueError("release_rows need to be greater than 0.") 2804 self.row_count = -init_release_rows 2805 self.cv = threading.Condition() 2806 self.callback = callback 2807 self.default_rows = init_release_rows 2808 self.disable = False 2809 2810 def __deepcopy__(self, memodict): 2811 return self 2812 2813 def reset(self): 2814 with self.cv: 2815 self.row_count = -self.default_rows 2816 self.cv.notify_all() 2817 2818 def update_batched_size(self, batch_size): 2819 # sanity check 2820 if isinstance(batch_size, int) and batch_size <= 0: 2821 raise ValueError("batch_size need to be greater than 0.") 2822 2823 # should only use before the pipeline creates 2824 self.row_count *= batch_size 2825 self.default_rows *= batch_size 2826 2827 def block_func(self): 2828 """ 2829 Function for handing blocking condition. 2830 2831 Returns: 2832 bool, True. 2833 """ 2834 with self.cv: 2835 # if disable is true, the always evaluate to true 2836 not_time_out = self.cv.wait_for(lambda: (self.row_count < 0 or self.disable), 2837 timeout=get_callback_timeout()) 2838 # time_out will be False if time out occurs 2839 if not not_time_out: 2840 logger.warning("Timeout happened in sync_wait, maybe dataset.sync_update(condition=...) " 2841 "is not added after dataset.create_dict_iterator(...), now disabling lock.") 2842 self.disable = True 2843 self.row_count += 1 2844 return True 2845 2846 def release_func(self, pass_rows=None, data=None): 2847 with self.cv: 2848 if pass_rows is None: 2849 pass_rows = self.default_rows 2850 self.row_count -= pass_rows 2851 if self.callback is not None: 2852 self.callback(data) 2853 self.cv.notify_all() 2854 2855 def disable_lock(self): 2856 with self.cv: 2857 self.disable = True 2858 self.cv.notify_all() 2859 2860 2861class PaddedBatchDataset(UnionBaseDataset): 2862 """ 2863 The result of applying Batch operation to the input dataset. 2864 2865 Args: 2866 input_dataset (Dataset): Input Dataset to be batched. 2867 batch_size (Union[int, function]): The number of rows each batch is created with. An 2868 int or callable which takes exactly 1 parameter, BatchInfo. 2869 drop_remainder (bool, optional): Determines whether or not to drop the last 2870 possibly incomplete batch. Default: ``False``. If True, and if there are less 2871 than batch_size rows available to make the last batch, then those rows will 2872 be dropped and not propagated to the child node. 2873 num_parallel_workers (int, optional): Number of workers to process the dataset in parallel. Default: ``None``. 2874 pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)} 2875 will pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0. 2876 """ 2877 2878 def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None, pad_info=None): 2879 super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers) 2880 2881 if PaddedBatchDataset._is_ancestor_of_repeat(input_dataset): 2882 logger.warning("Repeat is located before padded_batch, data from two epochs can be batched together.") 2883 2884 PaddedBatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) 2885 2886 # if batch_size is callable, set batch_size to 1 and batch_size_func to that callable function 2887 self.batch_size = batch_size if not callable(batch_size) else 1 2888 self.batch_size_func = None if not callable(batch_size) else batch_size 2889 2890 self.drop_remainder = replace_none(drop_remainder, False) 2891 2892 self.pad = bool(pad_info is not None) 2893 self.pad_info = replace_none(pad_info, dict()) 2894 2895 def parse(self, children=None): 2896 return cde.BatchNode(children[0], self.batch_size, self.drop_remainder, self.pad, [], 2897 [], self.batch_size_func, None, self.pad_info, None) 2898 2899 @staticmethod 2900 def _is_ancestor_of_repeat(dataset): 2901 """ 2902 Utility function to find the case where repeat is used before batch. 2903 2904 Args: 2905 dataset (Dataset): Dataset to be checked. 2906 2907 Returns: 2908 bool, whether repeat is used before batch. 2909 """ 2910 if isinstance(dataset, RepeatDataset): 2911 return True 2912 flag = False 2913 for input_dataset in dataset.children: 2914 flag = flag | PaddedBatchDataset._is_ancestor_of_repeat(input_dataset) 2915 return flag 2916 2917 @staticmethod 2918 def _update_batch_size_for_syncwait(dataset, batch_size): 2919 """ 2920 Utility function to notify batch size to sync_wait. 2921 2922 Args: 2923 dataset (Dataset): Dataset to be checked. 2924 batch_size (int): batch size to notify. 2925 """ 2926 if isinstance(dataset, SyncWaitDataset): 2927 dataset.update_sync_batch_size(batch_size) 2928 for input_dataset in dataset.children: 2929 PaddedBatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) 2930 2931 def __deepcopy__(self, memodict): 2932 return self.__safe_deepcopy__(memodict, exclude=("batch_size_func", "__transfer_dataset__")) 2933 2934 2935class SyncWaitDataset(UnionBaseDataset): 2936 """ 2937 The result of adding a blocking condition to the input Dataset. 2938 2939 Args: 2940 input_dataset (Dataset): Input dataset to apply flow control. 2941 num_batch (int): Number of batches without blocking at the start of each epoch. 2942 condition_name (str): Condition name that is used to toggle sending next row. 2943 callback (function): Callback function that will be invoked when sync_update is called. Default: ``None``. 2944 2945 Raises: 2946 RuntimeError: If condition name already exists. 2947 """ 2948 2949 def __init__(self, input_dataset, condition_name, num_batch, callback=None): 2950 super().__init__(children=input_dataset) 2951 2952 # set to the default value, waiting for the batch to update it 2953 self._condition_name = condition_name 2954 if isinstance(num_batch, int) and num_batch <= 0: 2955 raise ValueError("num_batch need to be greater than 0.") 2956 2957 self._pair = BlockReleasePair(num_batch, callback) 2958 if self._condition_name in self.children[0].get_sync_notifiers(): 2959 raise RuntimeError("Condition name is already in use.") 2960 logger.info("Please remember to add dataset.sync_update(condition=%s), otherwise hanging will result. " 2961 "If dataset.sync_update(condition=%s) has already been added, you can ignore the info.", 2962 condition_name, condition_name) 2963 2964 def parse(self, children=None): 2965 return cde.SyncWaitNode(children[0], self._condition_name, self._pair.block_func) 2966 2967 def get_sync_notifiers(self): 2968 return {**self.children[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}} 2969 2970 def is_sync(self): 2971 return True 2972 2973 def update_sync_batch_size(self, batch_size): 2974 if isinstance(batch_size, int) and batch_size <= 0: 2975 raise ValueError("num_batch need to be greater than 0.") 2976 self._pair.update_batched_size(batch_size) 2977 2978 def disable_sync(self): 2979 logger.info("Disabling Sync") 2980 self._pair.disable_lock() 2981 2982 @staticmethod 2983 def _is_ancestor_of_batch(dataset): 2984 """ 2985 Utility function to find the case where sync_wait is used before batch. 2986 2987 Args: 2988 dataset (Dataset): Dataset to be checked. 2989 2990 Returns: 2991 bool, whether sync_wait is used before batch. 2992 """ 2993 if isinstance(dataset, (BatchDataset, PaddedBatchDataset)): 2994 return True 2995 flag = False 2996 for input_dataset in dataset.children: 2997 flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset) 2998 return flag 2999 3000 def iterator_bootstrap(self): 3001 self._pair.reset() 3002 3003 3004class ShuffleDataset(UnionBaseDataset): 3005 """ 3006 The result of applying Shuffle operation to the input Dataset. 3007 3008 Args: 3009 input_dataset (Dataset): Input Dataset to be shuffled. 3010 buffer_size (int): Size of the buffer. 3011 3012 Raises: 3013 RuntimeError: If exist sync operations before shuffle. 3014 """ 3015 3016 def __init__(self, input_dataset, buffer_size): 3017 super().__init__(children=input_dataset) 3018 self.buffer_size = buffer_size 3019 self.reshuffle_each_epoch = True 3020 3021 if self.is_sync(): 3022 raise RuntimeError("No shuffle after sync operators.") 3023 3024 def parse(self, children=None): 3025 return cde.ShuffleNode(children[0], self.buffer_size, self.reshuffle_each_epoch) 3026 3027 def is_shuffled(self): 3028 return True 3029 3030 3031# Pyfunc collection for multiprocess pyfunc 3032# This global variable will only be used within subprocesses 3033_OP_NAME = dict() 3034_OP_PROCESS = dict() 3035 3036 3037# PythonCallable wrapper for multiprocess pyfunc 3038class _PythonCallable: 3039 """ 3040 Internal Python function wrapper for multiprocessing pyfunc. 3041 """ 3042 3043 def __init__(self, py_callable, idx, pool=None): 3044 # Original Python callable from user. 3045 self.py_callable = py_callable 3046 # Process pool created for current iterator. 3047 self.pool = pool 3048 # Python callable index 3049 self.idx = idx 3050 3051 def __call__(self, *args): 3052 result = None 3053 get_data_from_worker_process = False 3054 while get_data_from_worker_process is False: 3055 if self.pool.is_running() and check_iterator_cleanup() is False: 3056 try: 3057 result = self.pool.execute(self.idx, *args) 3058 except multiprocessing.TimeoutError: 3059 continue 3060 get_data_from_worker_process = True 3061 else: 3062 # worker process is stopped 3063 logger.info("The worker process of map operation is stopped. " 3064 "So return None to main thread and break the main thread.") 3065 return None 3066 # got value from worker process 3067 if not isinstance(result, tuple) and get_data_from_worker_process is True: 3068 result = (result,) 3069 return result 3070 3071 def to_json(self): 3072 return self.py_callable.to_json() 3073 3074 3075# used when python_multiprocessing=True in map 3076class Pipe: 3077 """ 3078 Class to handle communication between the master process and the worker processes. 3079 """ 3080 3081 def __init__(self, warning_ctl, shared_memory=False, max_rowsize=16): 3082 self.shared_memory = shared_memory 3083 self.eof = multiprocessing.Event() 3084 if self.shared_memory: 3085 self.in_queue = _SharedQueue(1, warning_ctl, max_rowsize=max_rowsize[0]) 3086 self.res_queue = _SharedQueue(1, warning_ctl, max_rowsize=max_rowsize[1]) 3087 else: 3088 self.in_queue = _Queue(1) 3089 self.res_queue = _Queue(1) 3090 self.in_queue.cancel_join_thread() # Ensure that the process does not hung when exiting 3091 3092 def master_send(self, func_index, data): 3093 self.in_queue.put_nowait((func_index, *data)) 3094 3095 def master_receive(self): 3096 if self.eof is None: 3097 raise RuntimeError("EOF is none when get data from worker.") 3098 if self.eof.is_set(): 3099 return None 3100 return self.res_queue.get(timeout=1) 3101 3102 def master_close(self): 3103 self.eof.set() 3104 self.send_finish_signal_to_worker() 3105 self.send_finish_signal() 3106 3107 def send_finish_signal(self): 3108 self.worker_send(None) 3109 3110 def send_finish_signal_to_worker(self): 3111 self.master_send(0, "QUIT") 3112 3113 def worker_send(self, data): 3114 self.res_queue.put_until(data, timeout=1, exit_signal=self.eof) 3115 3116 def worker_receive(self): 3117 result = self.in_queue.get_until(timeout=1, exit_signal=self.eof) 3118 if result is None: 3119 return result 3120 if len(result) == 1: 3121 raise RuntimeError(f"Corrupted data. Worker received {len(result)} elements, it should be more than 1.") 3122 func_index, *data = result 3123 return func_index, tuple(data) 3124 3125 3126def _main_process_already_exit(): 3127 """ 3128 Judge whether main process already exit. 3129 """ 3130 ppid = os.getppid() 3131 3132 if (platform.system().lower() != 'windows' and 3133 not _PythonMultiprocessing.is_process_alive(ppid)): 3134 return True 3135 return False 3136 3137 3138def _worker_loop(operations, pipe, worker_id): 3139 """ 3140 Multiprocess worker process loop. 3141 """ 3142 # Ensure that the process does not hung when exiting 3143 pipe.res_queue.cancel_join_thread() 3144 3145 def _ignore_sigint(): 3146 """ 3147 We need to ignore sigint signal here so subprocesses can exit normally and clear. 3148 """ 3149 signal.signal(signal.SIGINT, signal.SIG_IGN) 3150 3151 # If the default random seed has not been changed, there is no need to fix the randomness. 3152 # Otherwise, set the random seed for each child process to "base_seed + worker_id" to ensure 3153 # that the random results of each process are different. 3154 if get_seed() != 5489: 3155 set_seed(get_seed() + worker_id) 3156 while not _main_process_already_exit(): 3157 _ignore_sigint() 3158 3159 result = pipe.worker_receive() 3160 if result is None: 3161 return 3162 (idx, input_tensors) = result 3163 if input_tensors == "QUIT": 3164 break 3165 try: 3166 output_tensors = operations[idx](*input_tensors) 3167 3168 pipe.worker_send(output_tensors) 3169 except Exception: 3170 pipe.worker_send(ExceptionHandler(where="in map(or batch) worker and execute Python function")) 3171 # Do not return 3172 3173 # release the queue when stop the worker by master 3174 del pipe.in_queue 3175 del pipe.res_queue 3176 3177 3178def worker_target(operations, worker_id): 3179 return lambda pipe: _worker_loop(operations, pipe, worker_id) 3180 3181 3182class _MPWorker(multiprocessing.Process): 3183 """ 3184 Worker process for multiprocessing. 3185 """ 3186 3187 def __init__(self, operations, warning_ctl, max_rowsize=16, worker_id=0): 3188 shared_memory = get_enable_shared_mem() 3189 self.pipe = Pipe(warning_ctl, shared_memory=shared_memory, max_rowsize=max_rowsize) 3190 self.check_interval = get_multiprocessing_timeout_interval() 3191 super().__init__(target=worker_target(operations, worker_id), name="MapWorker" + str(worker_id), 3192 args=(self.pipe,), daemon=True) 3193 3194 def execute(self, idx, *args): 3195 """Acquiring data from a worker in an infinite loop""" 3196 self.pipe.master_send(idx, args) 3197 time_s = time.time() 3198 wait_count = 1 3199 while True: 3200 cost_time = time.time() - time_s 3201 if cost_time / self.check_interval >= wait_count: 3202 wait_count += 1 3203 logger.warning("It has been waiting for " + "%.3f" % cost_time + "s because the sub-process " 3204 "worker of the map operation is hanging. " 3205 "Check whether the user defined data transform is too slow or the " 3206 "output data is too large. You can also set the timeout interval by " 3207 "ds.config.set_multiprocessing_timeout_interval to adjust the output frequency " 3208 "of this log.") 3209 pid = self.pid 3210 logger.warning("Map worker subprocess ID {} is stuck.".format(pid)) 3211 install_status, _ = subprocess.getstatusoutput("py-spy --version") 3212 if install_status == 0: 3213 stack = subprocess.getoutput("py-spy dump -p {} -l".format(pid)) 3214 logger.warning("Map worker subprocess stack:\n{}".format(stack)) 3215 else: 3216 logger.warning("Please `pip install py-spy` to get the stacks of the stuck process.") 3217 try: 3218 res = self.pipe.master_receive() 3219 # Because there is no need to copy when creating Tensors in the C++layer, it reduces the time 3220 # from np.ndarray to C++Tensor creation. However, when using shared memory in multiple processes, 3221 # the address of the shared memory will always be passed to subsequent nodes in the dataset pipeline, 3222 # and the shared memory will also be written by the current node, causing dirty data to be accessed 3223 # by subsequent nodes in the pipeline. So make a memory copy here to solve the problem of 3224 # shared memory being contaminated. 3225 if get_enable_shared_mem(): 3226 res = copy.deepcopy(res) 3227 except queue.Empty: 3228 continue 3229 if res is None: 3230 # receive finish signal 3231 return None 3232 if isinstance(res, ExceptionHandler): 3233 res.reraise() 3234 return res 3235 3236 def close(self): 3237 try: 3238 if self.is_alive(): 3239 # release the eager executor which is used by current process 3240 transforms.transforms.clean_unused_executors() 3241 3242 logger.info(f"Closing worker with PID: {self.pid}") 3243 self.pipe.master_close() 3244 # del the handle which hold by master 3245 del self.pipe.in_queue 3246 del self.pipe.res_queue 3247 super().terminate() 3248 super().join() 3249 super().close() 3250 3251 except ValueError: 3252 # Process has been closed already 3253 return 3254 return 3255 3256 def is_alive(self): 3257 try: 3258 return super().is_alive() 3259 except ValueError: 3260 return False 3261 3262 3263class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime): 3264 """ 3265 A wrapper to multiprocessing.pool that performs cleanup and ensure proper termination of forked processes. 3266 """ 3267 3268 class _ExceptHookHandler: 3269 """ 3270 Internal class ExceptionHandler 3271 """ 3272 3273 def __init__(self): 3274 self.origin_hook = sys.excepthook 3275 sys.excepthook = self.__handler_exception 3276 3277 @staticmethod 3278 def mp_pool_exit_preprocess(): 3279 if check_iterator_cleanup() is False: 3280 # Set the iterator_cleanup flag to True before exiting, and wait 3s for all apply_async 3281 # applied to the multiprocessing task to prevent multiprocessing from hang when exiting 3282 _set_iterator_cleanup() 3283 time.sleep(3) 3284 3285 def __handler_exception(self, ex_type, value, tb): 3286 self.origin_hook(ex_type, value, tb) 3287 self.mp_pool_exit_preprocess() 3288 3289 def __init__(self, op_name, num_parallel_workers, operations, max_rowsize=16): 3290 super(_PythonMultiprocessing, self).__init__() 3291 self.op_name = op_name 3292 self.num_parallel_workers = num_parallel_workers 3293 self.operations = operations 3294 self.max_rowsize = max_rowsize 3295 3296 self.workers = None 3297 self.pids = None 3298 self.op_id = -1 3299 3300 self.queues_map = {} 3301 self.next_queue = 0 3302 3303 self.eot = None 3304 self.watch_dog = None 3305 self.ppid = os.getpid() 3306 self.hook = None 3307 self.warning_ctl = None 3308 # cache thread (get_ident()) to worker_id mapping in Python layer 3309 self.python_threads_to_workers = {} 3310 self.eof = None 3311 3312 def __del__(self): 3313 try: 3314 self.terminate() 3315 except TypeError: 3316 pass 3317 3318 # This wait function is for cleaning zombie subprocesses 3319 @staticmethod 3320 def wait_pid(): 3321 """ 3322 This function is used by the main process to release subprocess resources. 3323 """ 3324 try: 3325 while True: 3326 child_pid, _ = os.waitpid(-1, os.WNOHANG) 3327 if child_pid == 0: 3328 break 3329 except OSError: 3330 # waitpid may be failed for some reasons so we ignore this error 3331 pass 3332 3333 # Dataset need watch_dog thread to monitoring fork multi-processing, 3334 # and thread can't be a member function otherwise python won't collect and release resources. 3335 @staticmethod 3336 def _watch_dog(eot, workers): 3337 """ 3338 This thread is for monitoring subprocesses forked by GeneratorDataset/map/batch 3339 """ 3340 if not isinstance(workers, list): 3341 raise TypeError("[Internal Error] The 2nd parameter of watch dog thread should be list of process, " 3342 "but got {}.".format(type(workers))) 3343 3344 while not eot.is_set(): 3345 # Monitoring and count how many subprocesses already exit 3346 clear_subprocess_timeout = _PythonMultiprocessing._monitor_subprocess_exit(workers) 3347 # If find subprocess exit, we will wait for 30s and do some waitpid operations 3348 if clear_subprocess_timeout > 0: 3349 start = time.time() 3350 while time.time() - start < clear_subprocess_timeout: 3351 # We need to distinguishing get_dataset_size or train finished normally and hang scenario. 3352 # If get_dataset_size or train finished normally, _stop_subprocess can be execute and 3353 # self.need_abort can be set to True. If main process is hang in get(), self.need_abort 3354 # will never set to True, then we wait for 30s and kill main process 3355 if eot.is_set(): 3356 return 3357 # Sometimes subprocess may be zombie, so in 30s we can wait and do some useful tasks(waitpid). 3358 _PythonMultiprocessing.wait_pid() 3359 # multiprocessing.queue may hang in .get() forever when put() process was killed. 3360 # We have to exit main process otherwise main process will hang. 3361 _PythonMultiprocessing._terminate_processes(workers) 3362 logger.critical("The subprocess of dataset may exit unexpected or be killed, " 3363 "main process will exit. If this is not an artificial operation, you can use " 3364 "ds.config.set_enable_watchdog(False) to block this error.") 3365 os.kill(os.getpid(), signal.SIGTERM) 3366 3367 # release the workers 3368 del workers 3369 3370 @staticmethod 3371 def _terminate_processes(processes): 3372 """Terminate subprocesses""" 3373 3374 for p in processes: 3375 try: 3376 if p.exitcode is None: 3377 p.terminate() 3378 except Exception: # pylint: disable=broad-except 3379 # process has been closed already 3380 pass 3381 for p in processes: 3382 if p._closed is False: # pylint: disable=W0212 3383 # We don't use w.join because join can only used in main process or join will raise an error. 3384 p._popen.wait() # pylint: disable=W0212 3385 3386 # Monitor the exit number of subprocesses 3387 @staticmethod 3388 def _monitor_subprocess_exit(workers): 3389 """ 3390 To monitor whether process is exit. 3391 3392 Args: 3393 workers (list of multiprocessing.Process): multiprocessing.Process. 3394 3395 Returns: 3396 int, the timeout(in seconds) when process exit. 3397 """ 3398 for w in workers: 3399 try: 3400 exit_code = w.exitcode 3401 if exit_code is not None: 3402 # For kill -9, we can exit quickly 3403 if exit_code == -9: 3404 return 1 3405 # For kill -15, we still exit after 30s 3406 if exit_code == -15: 3407 return 30 3408 # In some cases the subprocess has been killed but the exitcode is still None. 3409 # So we use os.kill(pid, 0) to check if it is alive. 3410 subprocess_alive = _PythonMultiprocessing.is_process_alive(w.pid) 3411 if not subprocess_alive: 3412 # Like kill -15, we wait 30s before exit 3413 return 30 3414 except ValueError: 3415 # process has been closed already 3416 return 0 3417 return 0 3418 3419 @staticmethod 3420 def is_process_alive(pid): 3421 """ 3422 Check if the process is alive or not. 3423 Note: We hit a deadlock when we use psutil or w.exitcode to check whether a process is alive. 3424 Instead we use os.kill(ppid, 0). 3425 3426 Args: 3427 pid: pid of the process to be checked 3428 3429 Returns: 3430 True if the process is alive 3431 """ 3432 3433 try: 3434 os.kill(pid, 0) 3435 except OSError: 3436 return False 3437 return True 3438 3439 # When main process exit, subprocesses will be terminate 3440 @staticmethod 3441 def _clean_process(ppid, workers, quit_signal): 3442 """ 3443 This is the execute function of clean process, if we found main process exited, we will clean subprocesses. 3444 3445 Args: 3446 ppid: The process id of main process. 3447 workers: The list of subprocesses. 3448 quit_signal: The flag of quit. 3449 """ 3450 signal.signal(signal.SIGINT, signal.SIG_IGN) 3451 while _PythonMultiprocessing.is_process_alive(ppid): 3452 if quit_signal.is_set(): 3453 return 3454 time.sleep(0.1) 3455 3456 _PythonMultiprocessing._terminate_processes(workers) 3457 del workers 3458 os.kill(os.getpid(), signal.SIGTERM) 3459 3460 def launch(self, op_id=-1): 3461 """ 3462 Launch Python multiprocessing pool. 3463 3464 Args: 3465 pop_id: ID for operation to have Python multiprocessing pool launched 3466 3467 Returns: 3468 Python multiprocssing pool is launched. 3469 """ 3470 self.python_threads_to_workers = {} 3471 self.op_id = op_id 3472 logger.info("Launching new Python Multiprocessing pool for Op:" + str(self.op_id)) 3473 if self.is_mp_enabled(): 3474 message = "Launching a new Python multiprocessing pool while a pool already exists!" + \ 3475 " The existing pool will be terminated first." 3476 logger.warning(message) 3477 self.terminate() 3478 self.reset() 3479 self.create_pool() 3480 3481 def create_pool(self): 3482 """ 3483 3484 Returns: 3485 3486 """ 3487 if get_enable_shared_mem(): 3488 _check_shm_usage(self.num_parallel_workers, 1, self.max_rowsize[0], self.max_rowsize[1]) 3489 3490 if self.workers is not None: 3491 raise Exception("Pool was already created, close it first.") 3492 3493 # Let gc collect unreferenced memory to avoid child processes in the pool to do it 3494 gc.collect() 3495 3496 # Construct python worker processes 3497 self.workers = [] 3498 self.warning_ctl = multiprocessing.Value('i', 0) 3499 for worker_id in range(self.num_parallel_workers): 3500 worker = _MPWorker(self.operations, self.warning_ctl, self.max_rowsize, worker_id) 3501 worker.start() 3502 self.workers.append(worker) 3503 3504 logger.info("Op: " + str(self.op_id) + " Python multiprocessing pool workers' PIDs: " + str(self.get_pids())) 3505 3506 self.hook = _PythonMultiprocessing._ExceptHookHandler() 3507 3508 # The op (Map, Batch, etc) multiprocessing will launch a watch dog thread for monitoring sub processes 3509 self._launch_watch_dog() 3510 3511 atexit.register(self.terminate) 3512 3513 def terminate(self): 3514 # close watch dog first and then close all the workers 3515 self.abort_watchdog() 3516 self.close_all_workers() 3517 if hasattr(self, "warning_ctl"): 3518 del self.warning_ctl 3519 3520 def get_pids(self): 3521 """ 3522 Get list of worker's PIDs 3523 3524 Returns: 3525 list of strings 3526 """ 3527 if not self.is_mp_enabled(): 3528 return [] 3529 if not self.pids: 3530 self.pids = [] 3531 if self.workers: 3532 for w in self.workers: 3533 try: 3534 self.pids.append(w.pid) 3535 except ValueError: 3536 continue 3537 return self.pids 3538 3539 def add_new_workers(self, num_new_workers): 3540 logger.info( 3541 "Increasing num_parallel_workers of Python Multiprocessing pool for Op:" + str(self.op_id) + 3542 ", old num_workers=" + str(self.num_parallel_workers) + " new num_workers=" + str( 3543 self.num_parallel_workers + 3544 num_new_workers) + ".") 3545 self.terminate() 3546 self.num_parallel_workers += num_new_workers 3547 self.launch(self.op_id) 3548 3549 def remove_workers(self, num_removed_workers): 3550 logger.info( 3551 "Decreasing num_parallel_workers of Python Multiprocessing pool for Op:" + str(self.op_id) + 3552 ", old num_workers=" + str(self.num_parallel_workers) + " new num_workers=" + str( 3553 self.num_parallel_workers - 3554 num_removed_workers) + ".") 3555 self.terminate() 3556 self.num_parallel_workers -= num_removed_workers 3557 self.launch(self.op_id) 3558 3559 def is_mp_enabled(self): 3560 return self.workers is not None 3561 3562 def execute(self, idx, *args): 3563 """ 3564 Execute 3565 """ 3566 t_id = threading.get_ident() 3567 # get the worker_id from Python layer cache first, get from Cpp layer if not found. 3568 worker_id = self.python_threads_to_workers.setdefault(t_id, self.get_thread_to_worker()) 3569 if worker_id >= len(self.workers): 3570 raise RuntimeError("[Internal] worker_id value is greater than number of available workers!") 3571 3572 # todo check_iterator_cleanup 3573 if self.is_running() and check_iterator_cleanup() is False: 3574 return self.workers[worker_id].execute(idx, *args) 3575 3576 return None 3577 3578 def _launch_watch_dog(self): 3579 """ 3580 We will launch a watchdog thread and a clean process to cleaning subprocess when there is process was killed. 3581 The watchdog thread will cleanup subprocesses and main process when one of the subprocesses was killed. 3582 The cleaning subprocess will cleanup subprocesses when main process was killed. 3583 """ 3584 if platform.system().lower() != 'windows': 3585 self.eof = multiprocessing.Event() 3586 self.cleaning_process = multiprocessing.Process(target=self._clean_process, 3587 name="MapCleanProcess", 3588 args=(self.ppid, self.workers, self.eof), 3589 daemon=True) 3590 self.cleaning_process.start() 3591 3592 if get_enable_watchdog(): 3593 self.eot = threading.Event() 3594 self.watch_dog = threading.Thread(target=self._watch_dog, 3595 name="MapWatchDog", 3596 args=(self.eot, self.workers + [self.cleaning_process]), 3597 daemon=True) 3598 self.watch_dog.start() 3599 3600 def _abort_watchdog(self): 3601 if not self.eot.is_set(): 3602 self.eot.set() 3603 3604 def abort_watchdog(self): 3605 if hasattr(self, 'watch_dog') and self.watch_dog is not None and hasattr(self, 'eot') and self.eot is not None: 3606 self._abort_watchdog() 3607 if hasattr(self, 'cleaning_process') and self.cleaning_process is not None: 3608 if hasattr(self, 'eof') and self.eof is not None and not self.eof.is_set(): 3609 self.eof.set() 3610 _PythonMultiprocessing._terminate_processes([self.cleaning_process]) 3611 del self.cleaning_process 3612 3613 def is_running(self): 3614 if hasattr(self, 'workers') and self.workers is not None: 3615 return all([w.is_alive() for w in self.workers]) 3616 return False 3617 3618 def close_all_workers(self): 3619 """Close all the subprocess workers""" 3620 if hasattr(self, 'workers') and self.workers is not None: 3621 for w in self.workers: 3622 w.close() 3623 check_interval = get_multiprocessing_timeout_interval() 3624 for w in self.workers: 3625 try: 3626 subprocess_file_descriptor = w.sentinel 3627 st = time.time() 3628 while _PythonMultiprocessing.is_process_alive(w.pid): 3629 time.sleep(0.01) # sleep 10ms, waiting for the subprocess exit 3630 if time.time() - st > check_interval: 3631 logger.warning("Waiting for the subprocess worker [{}] to exit.".format(w.pid)) 3632 st += check_interval 3633 except ValueError as e: 3634 if "process object is closed" in str(e): 3635 continue 3636 raise e 3637 try: 3638 if w.is_alive(): 3639 os.close(subprocess_file_descriptor) 3640 except OSError as e: 3641 # Maybe the file descriptor had been released, so ignore the 'Bad file descriptor' 3642 if "Bad file descriptor" not in str(e): 3643 raise e 3644 3645 # use clear to release the handle which is better than self.workers = None 3646 self.workers.clear() 3647 self.workers = None 3648 self.pids = None 3649 3650 3651class MapDataset(UnionBaseDataset): 3652 """ 3653 The result of applying the Map operation to the input Dataset. 3654 3655 Args: 3656 input_dataset (Dataset): Input Dataset to be mapped. 3657 operations (Union[list[TensorOperation], list[functions]]): A function mapping a nested structure of tensors 3658 to another nested structure of tensor. Default: ``None``. 3659 input_columns (Union[str, list[str]]): List of names of the input columns. 3660 Default: ``None``, the operations will be applied on the first columns in the dataset. 3661 The size of the list should match the number of inputs of the first operation. 3662 output_columns (Union[str, list[str]], optional): List of names of the output columns. 3663 The size of the list should match the number of outputs of the last operation. 3664 Default: ``None``, output columns will be the input columns, i.e., the columns will 3665 be replaced. 3666 num_parallel_workers (int, optional): Number of workers to process the dataset 3667 in parallel. Default: ``None``. 3668 python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This 3669 option could be beneficial if the Python operation is computational heavy. Default: ``False``. 3670 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 3671 Default: ``None``, which means no cache is used. 3672 callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called. Default: ``None``. 3673 max_rowsize(Union[int, list[int]], optional): Maximum size of row in MB that is used for shared memory 3674 allocation to copy data between processes, the total occupied shared memory will increase as 3675 ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set to -1, 3676 shared memory will be dynamically allocated with the actual size of data. This is only used if 3677 ``python_multiprocessing`` is set to True. If it is an int value, it represents ``input_columns`` and 3678 ``output_columns`` use this value as the unit to create shared memory. If it is a list, the first element 3679 represents the ``input_columns`` use this value as the unit to create shared memory, and the second element 3680 represents ``output_columns`` use this value as the unit to create shared memory. Default: 16. 3681 offload (bool, optional): Flag to indicate whether offload is used. Default: ``None``. 3682 """ 3683 3684 def __init__(self, input_dataset, operations=None, input_columns=None, output_columns=None, 3685 num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16, 3686 offload=None): 3687 super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers, cache=cache) 3688 self.operations = to_list(operations) 3689 for op in self.operations: 3690 # user define c_vision.HWC2CHW without parentheses is error 3691 if type(op) == type: # pylint: disable=unidiomatic-typecheck 3692 raise ValueError("Parameter operations's element of method map should be a dataset processing " 3693 "operation instance, but got: {}. It may be missing parentheses for " 3694 "instantiation.".format(op)) 3695 if not isinstance(op, (c_transforms.TensorOperation, py_transforms.PyTensorOperation)) \ 3696 and not callable(op): 3697 raise ValueError("Parameter operations's element of method map should be a python function or " 3698 "class method which should be callable, but got: {}. It doesn't need parentheses " 3699 "for python function or class method.".format(op)) 3700 3701 self.input_columns = to_list(input_columns) 3702 self.output_columns = to_list(output_columns) 3703 3704 # If output_columns were not provided then use input_columns 3705 self.output_columns = self.input_columns if not self.output_columns else self.output_columns 3706 3707 self.python_multiprocessing = python_multiprocessing 3708 self.process_pool = None 3709 3710 self.callbacks = to_list(callbacks) 3711 if isinstance(max_rowsize, int): 3712 self.max_rowsize = [max_rowsize] * 2 3713 else: 3714 self.max_rowsize = max_rowsize 3715 self.offload = offload 3716 3717 def parse(self, children=None): 3718 operations = self.__decompose_callable_operations() 3719 3720 count_old_transforms, count_new_transforms, count_non_data_vision_transforms = \ 3721 self.__count_transforms(operations) 3722 count_pyfunc = self.__count_pyfuncs(operations) 3723 if count_new_transforms + count_pyfunc == len(operations): 3724 prev_op = None 3725 for op in operations: 3726 # skip user added DebugHook to avoid changing to Py-implementation. 3727 if self.__is_debug_hook_op(op): 3728 if prev_op: 3729 # manually set previous_op_name 3730 prev_op_name = self.__parse_op_name(prev_op) 3731 op.set_previous_op_name(prev_op_name) 3732 continue 3733 if op.implementation is None: 3734 if prev_op and prev_op.implementation == Implementation.PY: 3735 op.implementation = Implementation.PY 3736 else: 3737 op.implementation = Implementation.C 3738 prev_op = op 3739 operations = self.__insert_debug_wrapper(operations) 3740 operations = transforms.transforms.Compose.reduce(operations) 3741 elif count_old_transforms + count_pyfunc + count_non_data_vision_transforms == len(operations): 3742 operations = self.__insert_debug_wrapper(operations) 3743 operations = transforms.py_transforms.Compose.reduce(operations) 3744 else: 3745 raise RuntimeError("Mixing old legacy c/py_transforms and new unified transforms is not allowed.") 3746 3747 self.operations = self.__process_final_operations(operations) 3748 self.prepare_multiprocessing() 3749 3750 callbacks = [cb.create_runtime_obj() for cb in self.callbacks] 3751 return cde.MapNode(children[0], self.operations, self.input_columns, self.output_columns, 3752 callbacks, OffloadToManualOffloadMode.get(self.offload), self.process_pool) 3753 3754 def __deepcopy__(self, memodict): 3755 return self.__safe_deepcopy__(memodict, exclude=("operations", "callbacks", "__transfer_dataset__")) 3756 3757 def __del__(self): 3758 if hasattr(self, "process_pool") and self.process_pool is not None: 3759 self.process_pool.terminate() 3760 del self.process_pool 3761 3762 @staticmethod 3763 def __parse_op_name(op): 3764 """ 3765 Utility method to get operation name. 3766 """ 3767 op_name = "" 3768 if isinstance(op, transforms.py_transforms_util.FuncWrapper): 3769 try: 3770 op_name = op.transform.__name__ 3771 except (AttributeError,): 3772 op_name = op.transform.__class__.__name__ 3773 else: 3774 op_name = op.__class__.__name__ 3775 return op_name 3776 3777 @staticmethod 3778 def __construct_debug_hook(previous_op_name=None, is_first_op=False): 3779 """ 3780 Wrap debug hook into FuncWrapper. 3781 """ 3782 inserted_functions = [] 3783 debug_hook_list = _get_debug_hook_list() 3784 if debug_hook_list: 3785 for fn in debug_hook_list: 3786 # making deep copy to allow each debug hook instance hold unique variables 3787 new_fn = copy.deepcopy(fn) 3788 new_fn.set_previous_op_name(previous_op_name) 3789 new_fn.set_is_first(is_first_op) 3790 inserted_func = transforms.py_transforms_util.FuncWrapper(new_fn) 3791 inserted_func.implementation = Implementation.PY 3792 inserted_functions.append(inserted_func) 3793 return inserted_functions 3794 3795 @staticmethod 3796 def __is_debug_hook_op(op): 3797 """ 3798 Check if the op is user added DebugHook and skip it to avoid changing transforms implementation. 3799 """ 3800 if isinstance(op, DebugHook): 3801 if not get_debug_mode(): 3802 raise ValueError("It is not allowed to inject DebugHook object in non-debug mode.") 3803 return True 3804 return False 3805 3806 @staticmethod 3807 def __count_pyfuncs(operations): 3808 """ 3809 Count the number of pyfuncs operations 3810 """ 3811 return sum([1 if isinstance(op, FuncWrapper) else 0 for op in operations]) 3812 3813 @staticmethod 3814 def __count_transforms(operations): 3815 """ 3816 Count the various flavors of transforms operations 3817 """ 3818 # Count the number of old legacy data and vision c_transforms and py_transforms 3819 count_old_transforms = sum( 3820 [1 if "c_transforms" in str(op) 3821 or isinstance(op, (c_transforms.TensorOperation, py_transforms.PyTensorOperation)) 3822 or ("py_transforms" in str(op) and not isinstance(op, FuncWrapper)) 3823 else 0 for op in operations]) 3824 # Count the number of new unified data and vision transforms 3825 count_new_transforms = sum([1 if hasattr(op, "implementation") and not isinstance(op, FuncWrapper) 3826 else 0 for op in operations]) 3827 # Count the number of non-data transforms and non-vision transforms 3828 count_non_data_vision_transforms = sum( 3829 [1 if "text.transforms" in str(op) or "audio.transforms" in str(op) else 0 for op in operations]) 3830 return count_old_transforms, count_new_transforms, count_non_data_vision_transforms 3831 3832 @staticmethod 3833 def __operation_valid_for_multiprocessing(op): 3834 if callable(op) and str(op).find("c_transform") < 0: 3835 return True 3836 return False 3837 3838 @staticmethod 3839 def __process_final_operations(operations): 3840 """ 3841 Build final list of operations 3842 """ 3843 operations_fin = [] 3844 for op in operations: 3845 if hasattr(op, "implementation"): 3846 if op.implementation == Implementation.C and not isinstance(op, (FuncWrapper, ToNumpy)): 3847 operations_fin.append(op.parse()) 3848 elif op.implementation == Implementation.PY: 3849 operations_fin.append(op) 3850 elif isinstance(op, (FuncWrapper, ToNumpy)): 3851 operations_fin.append(op) 3852 else: 3853 raise RuntimeError("Wrong implementation") 3854 else: 3855 if op and getattr(op, 'parse', None): 3856 operations_fin.append(op.parse()) 3857 else: 3858 operations_fin.append(op) 3859 return operations_fin 3860 3861 # Iterator bootstrap will be called on iterator construction. 3862 # A deep copy of Dataset object is created prior of iterator_bootstrap. 3863 # This method will create per iterator process pool and bind pyfunc execution to the pool. 3864 def prepare_multiprocessing(self): 3865 """ 3866 Per iterator bootstrap callback. 3867 """ 3868 if self.python_multiprocessing and platform.system().lower() == 'windows': 3869 logger.warning("Python multiprocessing is not supported on Windows platform.") 3870 return 3871 if self.python_multiprocessing and get_debug_mode(): 3872 logger.warning("Python multiprocessing is not supported in debug mode." 3873 " Ignoring Python multiprocessing for map operation.") 3874 return 3875 if self.python_multiprocessing: 3876 iter_specific_operations = [] 3877 callable_list = [] 3878 3879 # If user didn't specify num_parallel_workers, set it to default 3880 if self.num_parallel_workers is None: 3881 self.num_parallel_workers = get_num_parallel_workers() 3882 3883 # Pass #1, look for Python callables and build list 3884 for op in self.operations: 3885 # our c transforms is now callable and should not be run in Python multithreading 3886 if MapDataset.__operation_valid_for_multiprocessing(op): 3887 callable_list.append(op) 3888 3889 if callable_list: 3890 self.process_pool = _PythonMultiprocessing(str(self), self.num_parallel_workers, callable_list, 3891 self.max_rowsize) 3892 # Pass #2 3893 idx = 0 3894 for op in self.operations: 3895 # our c transforms is now callable and should not be run in Python multithreading 3896 if MapDataset.__operation_valid_for_multiprocessing(op): 3897 # Wrap Python callable into _PythonCallable 3898 iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool)) 3899 idx += 1 3900 else: 3901 # CPP ops remain the same 3902 iter_specific_operations.append(op) 3903 self.operations = iter_specific_operations 3904 3905 def __insert_debug_wrapper(self, operations): 3906 """ 3907 Insert DebuggerWrapper before and after each op if debug mode is on. 3908 """ 3909 if not get_debug_mode(): 3910 return operations 3911 first_op_name = self.__parse_op_name(operations[0]) 3912 inserted_operations = self.__construct_debug_hook(first_op_name, is_first_op=True) 3913 for op in operations: 3914 inserted_operations.append(op) 3915 op_name = self.__parse_op_name(op) 3916 inserted_operations.extend(self.__construct_debug_hook(op_name)) 3917 return inserted_operations 3918 3919 def __decompose_callable_operations(self): 3920 """ 3921 Decompose operations and build list of old legacy ops which are callable 3922 """ 3923 decomposed_operations = transforms.transforms.Compose.decompose(self.operations) 3924 operations = [] 3925 for op in decomposed_operations: 3926 if callable(op) and not hasattr(op, "implementation") and str(op).find( 3927 "c_transform") < 0 and not isinstance(op, c_transforms.TensorOperation) and \ 3928 not isinstance(op, py_transforms.PyTensorOperation): 3929 op = transforms.py_transforms_util.FuncWrapper(op) 3930 operations.append(op) 3931 return operations 3932 3933 3934class FilterDataset(UnionBaseDataset): 3935 """ 3936 The result of applying filter predicate to the input Dataset. 3937 3938 Args: 3939 input_dataset (Dataset): Input Dataset to be mapped. 3940 predicate (callable): Python callable which returns a boolean value. If False then filter the element. 3941 input_columns (Union[str, list[str]], optional): List of names of the input columns. 3942 Default: ``None``, the predicate will be applied to all columns in the dataset. 3943 num_parallel_workers (int, optional): Number of workers to process the dataset 3944 in parallel. Default: ``None``. 3945 """ 3946 3947 def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None): 3948 super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers) 3949 self.predicate = lambda *args: bool(predicate(*args)) 3950 self.input_columns = to_list(input_columns) 3951 3952 def parse(self, children=None): 3953 return cde.FilterNode(children[0], self.predicate, self.input_columns) 3954 3955 3956class RepeatDataset(UnionBaseDataset): 3957 """ 3958 The result of applying Repeat operation to the input Dataset. 3959 3960 Args: 3961 input_dataset (Dataset): Input Dataset to be repeated. 3962 count (int): Number of times the dataset will be repeated. Default: -1, repeat indefinitely. 3963 """ 3964 3965 def __init__(self, input_dataset, count): 3966 super().__init__(children=input_dataset) 3967 self.count = replace_none(count, -1) 3968 3969 def parse(self, children=None): 3970 return cde.RepeatNode(children[0], self.count) 3971 3972 3973class SkipDataset(UnionBaseDataset): 3974 """ 3975 The result of applying Skip operation to the input Dataset. 3976 3977 Args: 3978 input_dataset (Dataset): Input dataset to have elements skipped. 3979 count (int): Number of elements to be skipped in the dataset. 3980 """ 3981 3982 def __init__(self, input_dataset, count): 3983 super().__init__(input_dataset) 3984 self.count = count 3985 3986 def parse(self, children=None): 3987 return cde.SkipNode(children[0], self.count) 3988 3989 3990class TakeDataset(UnionBaseDataset): 3991 """ 3992 The result of applying Take operation to the input Dataset. 3993 3994 Args: 3995 input_dataset (Dataset): Input Dataset to have elements taken from. 3996 count (int): Number of elements to be taken from the dataset. 3997 """ 3998 3999 def __init__(self, input_dataset, count): 4000 super().__init__(children=input_dataset) 4001 self.count = count 4002 4003 def parse(self, children=None): 4004 return cde.TakeNode(children[0], self.count) 4005 4006 4007class ZipDataset(UnionBaseDataset): 4008 """ 4009 The result of applying Zip operation to the input Dataset. 4010 4011 Args: 4012 datasets (tuple): A tuple of datasets to be zipped together. 4013 4014 Raises: 4015 TypeError: If dataset is not an instance of Dataset. 4016 """ 4017 4018 def __init__(self, datasets): 4019 super().__init__(children=datasets) 4020 4021 def parse(self, children=None): 4022 return cde.ZipNode(children) 4023 4024 def is_sync(self): 4025 return any([c.is_sync() for c in self.children]) 4026 4027 4028class ConcatDataset(UnionBaseDataset): 4029 """ 4030 The result of applying Concat operation to the input Dataset. 4031 4032 Args: 4033 datasets (list): A list of datasets to be concatenated together. 4034 4035 Raises: 4036 TypeError: If dataset is not an instance of Dataset. 4037 ValueError: If there is no samples in the one of the datasets. 4038 """ 4039 4040 def __init__(self, datasets): 4041 super().__init__(children=datasets) 4042 for dataset in datasets: 4043 if not isinstance(dataset, Dataset): 4044 raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset)) 4045 self.datasets = datasets 4046 self._sampler = samplers.SequentialSampler(num_samples=None) 4047 4048 self.children_sizes_ = [c.get_dataset_size() for c in self.children] 4049 child_index = 0 4050 for item in self.children_sizes_: 4051 if item == 0: 4052 raise ValueError("There are no samples in the dataset number %d. Please make sure there are " 4053 "valid samples in the dataset." % child_index) 4054 child_index += 1 4055 4056 self._children_sizes = self.children_sizes_.copy() 4057 4058 # _children_flag_and_nums: A list of pair<int ,int>.The first element of pair is flag that characterizes 4059 # whether the dataset is mappable. The second element of pair is length of the dataset 4060 self._children_flag_and_nums = [] 4061 4062 # _children_start_end_index_: A list of pair<int ,int>.The elements of pair are used to characterize 4063 # the valid position of the dataset corresponding to the subscript when sampling 4064 self._children_start_end_index_ = [] 4065 for index, child in enumerate(self.children): 4066 tem_list = [-1, -1] 4067 self._children_start_end_index_.append(tem_list) 4068 dataset_len = self.children_sizes_[index] 4069 4070 from mindspore.dataset.engine.datasets_user_defined import GeneratorDataset 4071 if isinstance(child, GeneratorDataset) and not hasattr(child.source, "__getitem__"): 4072 dataset_len = 0 4073 self.children_sizes_[index] = 0 4074 4075 if isinstance(child, MappableDataset): 4076 self._children_flag_and_nums.append((0, dataset_len)) 4077 else: 4078 self._children_flag_and_nums.append((1, dataset_len)) 4079 4080 def parse(self, children=None): 4081 return cde.ConcatNode(children, self._sampler, self._children_flag_and_nums, self._children_start_end_index_, 4082 self._children_sizes) 4083 4084 def use_sampler(self, sampler): 4085 """ 4086 Set the distributedSampler to concat dataset 4087 4088 Args: 4089 sampler (Sampler): The sampler to use for the current dataset. 4090 Currently supported: DistributedSampler. 4091 4092 Raises: 4093 TypeError: If the sampler is not an instance of DistributedSampler 4094 ValueError: If the parameter shuffle of sampler is True 4095 ValueError: If the parameter NumSamples of sampler is not None. 4096 ValueError: If num_shards <=0. 4097 """ 4098 if not isinstance(sampler, (samplers.DistributedSampler, samplers.RandomSampler)): 4099 raise TypeError("The parameter %s of concat must be DistributedSampler or RandomSampler!" % sampler) 4100 4101 if isinstance(sampler, samplers.RandomSampler): 4102 if sampler.replacement: 4103 raise ValueError("The parameter replacement of RandomSampler must be False!") 4104 4105 if sampler.get_num_samples() is not None: 4106 raise ValueError("The parameter num_samples of RandomSampler is not support to be set!") 4107 4108 self._sampler = sampler 4109 self._children_sizes = [c.get_dataset_size() for c in self.children] 4110 4111 # Recursive access to other child concat nodes 4112 def set_child(node): 4113 for c in node.children: 4114 if isinstance(c, ConcatDataset): 4115 c.use_sampler(sampler) 4116 set_child(c) 4117 set_child(self) 4118 4119 return 4120 4121 if sampler.is_shuffled(): 4122 raise ValueError("The parameter shuffle of DistributedSampler must be False!") 4123 4124 if sampler.num_shards <= 0: 4125 raise ValueError("The parameter num_shards of DistributedSampler must be positive int!") 4126 4127 if sampler.get_num_samples() is not None: 4128 raise ValueError("The parameter num_samples of DistributedSampler is not support to be set!") 4129 4130 self.dataset_size = None 4131 4132 self._sampler = sampler 4133 cumulative_samples_nums = 0 4134 for index, child in enumerate(self.children): 4135 if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None: 4136 raise ValueError("The parameter NumSamples of %s is not support to be set!" % child) 4137 4138 if isinstance(child, (BatchDataset, PaddedBatchDataset)): 4139 raise TypeError("The parameter %s of concat must not be BatchDataset or PaddedBatchDataset!" % child) 4140 4141 # if child is mappable and the length is greater than 0 4142 if not self._children_flag_and_nums[index][0] and self._children_flag_and_nums[index][1]: 4143 4144 tem_value = cumulative_samples_nums + self._children_flag_and_nums[index][1] 4145 4146 if not self._children_flag_and_nums[index][1] >= sampler.num_shards: 4147 if tem_value < sampler.num_shards: 4148 self._children_start_end_index_[index][0] = cumulative_samples_nums 4149 self._children_start_end_index_[index][1] = tem_value 4150 else: 4151 self._children_start_end_index_[index][0] = cumulative_samples_nums 4152 self._children_start_end_index_[index][1] = tem_value % sampler.num_shards 4153 4154 tem_sampler = copy.deepcopy(sampler) 4155 tem_sampler.set_offset(cumulative_samples_nums) 4156 child.use_sampler(tem_sampler) 4157 4158 cumulative_samples_nums += self.children_sizes_[index] 4159 cumulative_samples_nums %= sampler.num_shards 4160 4161 4162class RenameDataset(UnionBaseDataset): 4163 """ 4164 The result of applying Rename operation to the input Dataset. 4165 4166 Args: 4167 input_dataset (Dataset): Input Dataset to be Renamed. 4168 input_columns (Union[str, list[str]]): List of names of the input columns. 4169 output_columns (Union[str, list[str]]): List of names of the output columns. 4170 """ 4171 4172 def __init__(self, input_dataset, input_columns, output_columns): 4173 super().__init__(children=input_dataset) 4174 self.input_column_names = to_list(input_columns) 4175 self.output_column_names = to_list(output_columns) 4176 4177 def parse(self, children=None): 4178 return cde.RenameNode(children[0], self.input_column_names, self.output_column_names) 4179 4180 4181def to_list(items): 4182 if items is None: 4183 return [] 4184 if isinstance(items, tuple): 4185 return list(items) 4186 if not isinstance(items, list): 4187 return [items] 4188 return items 4189 4190 4191class ProjectDataset(UnionBaseDataset): 4192 """ 4193 The result of applying Project operation to the input Dataset. 4194 4195 Args: 4196 input_dataset (Dataset): Input Dataset to be Projected. 4197 columns (Union[str, list[str]]): List of names of the columns to project. 4198 """ 4199 4200 def __init__(self, input_dataset, columns): 4201 super().__init__(children=input_dataset) 4202 self.columns = to_list(columns) 4203 4204 def parse(self, children=None): 4205 return cde.ProjectNode(children[0], self.columns) 4206 4207 4208class _ToDevice: 4209 """ 4210 Internal class to handle sending data to device. 4211 """ 4212 4213 def __init__(self, dataset, num_epochs): 4214 if get_debug_mode(): 4215 logger.error("MindData debugger cannot be used in dataset sink mode. Please manually turn off " 4216 "sink mode and try debugger again.") 4217 ir_tree, self.api_tree = dataset.create_ir_tree() 4218 4219 self._runtime_context = cde.PythonRuntimeContext() 4220 self._runtime_context.Init() 4221 self._to_device = cde.ToDevice(num_epochs) 4222 if dataset.get_init_step() != 0: 4223 init_step = dataset.get_init_step() 4224 dataset_size = dataset.get_dataset_size() 4225 self._to_device.Init(ir_tree, init_step, dataset_size) 4226 else: 4227 self._to_device.Init(ir_tree, 0, -1) 4228 self._runtime_context.AssignConsumer(self._to_device) 4229 4230 ITERATORS_LIST.append(weakref.ref(self)) 4231 _unset_iterator_cleanup() 4232 4233 def send(self): 4234 self._to_device.Send() 4235 4236 def stop_send(self): 4237 """ 4238 send stop send signal to pipeline, it is used when end of sequence is sent at the epoch end. 4239 """ 4240 self._to_device.StopSend() 4241 4242 def continue_send(self): 4243 """ 4244 send continue send signal to pipeline, it is used when end of sequence is sent at the epoch end. 4245 """ 4246 self._to_device.ContinueSend() 4247 4248 def get_data_info(self): 4249 """ 4250 Get type and shape of current batch. 4251 """ 4252 return self._to_device.GetDataInfo() 4253 4254 def get_mbuf_queue_size(self): 4255 """ 4256 Get element numbers inside mbuf. 4257 """ 4258 return self._to_device.GetMbufQueueSize() 4259 4260 def get_send_info(self): 4261 """ 4262 In sink mode, it returns the send information of dataset at this moment. 4263 Send information includes number of send batches, time summary of fetching data on host 4264 and time summary of sending data. 4265 """ 4266 return self._to_device.GetSendInfo() 4267 4268 def release(self): 4269 """ 4270 Manually terminate Device Queue instead of relying on out of scope destruction. 4271 """ 4272 if hasattr(self, '_runtime_context') and self._runtime_context: 4273 if hasattr(self, '_to_device') and self._to_device: 4274 self._runtime_context.Terminate() 4275 del self._to_device 4276 del self._runtime_context 4277 4278 def __deepcopy__(self, memodict): 4279 return self 4280 4281 def get_offload_model(self, col_names): 4282 """ 4283 Get offload model containing removed offload ops from pipeline. 4284 """ 4285 offload_model = GetOffloadModel(self._to_device, col_names) 4286 return offload_model 4287 4288 def _reset(self, step, dataset_size): 4289 self._to_device.Reset(step, dataset_size) 4290 4291 4292class TransferDataset(Dataset): 4293 """ 4294 The result of applying TDT operation to the input Dataset. 4295 4296 Args: 4297 input_dataset (Dataset): Input Dataset to be transferred. 4298 send_epoch_end (bool, optional): Whether to send end of sequence to device or not. Default: ``True``. 4299 create_data_info_queue (bool, optional): Whether to create queue which stores 4300 types and shapes of data or not. Default: ``False``. 4301 4302 Raises: 4303 TypeError: If device_type is empty. 4304 ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'. 4305 RuntimeError: If dataset is unknown. 4306 """ 4307 4308 def __init__(self, input_dataset, send_epoch_end=True, create_data_info_queue=False, queue_name=""): 4309 super().__init__(children=input_dataset) 4310 if queue_name == "": 4311 self.queue_name = str(uuid.uuid1()) 4312 logger.info(f"queue_name is newly generated. value is {self.queue_name}") 4313 else: 4314 self.queue_name = queue_name 4315 logger.info(f"queue_name is read from compile cache. value is {self.queue_name}") 4316 self.device_type = context.get_context("device_target") if context else "CPU" 4317 self.device_id = context.get_context("device_id") if context else 0 4318 4319 self._send_epoch_end = replace_none(send_epoch_end, True) 4320 self._create_data_info_queue = create_data_info_queue 4321 self._to_device = None 4322 self.column_name = input_dataset.get_col_names() 4323 4324 def parse(self, children=None): 4325 total_batch = 0 4326 if hasattr(self.children[0], "__total_batch__"): 4327 total_batch = self.children[0].__total_batch__ 4328 check_total_batch(total_batch) 4329 return cde.DataQueueNode(children[0], self.queue_name, self.device_type, self.device_id, self._send_epoch_end, 4330 total_batch, self._create_data_info_queue) 4331 4332 def create_dict_iterator(self, num_epochs=-1, output_numpy=False): 4333 raise RuntimeError("TransferDataset is not iterable.") 4334 4335 def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True): 4336 raise RuntimeError("TransferDataset is not iterable.") 4337 4338 def __iter__(self): 4339 raise RuntimeError("TransferDataset is not iterable.") 4340 4341 def output_shapes(self): 4342 raise RuntimeError("TransferDataset does not support obtaining output_shapes.") 4343 4344 def output_types(self): 4345 raise RuntimeError("TransferDataset does not support obtaining output_types.") 4346 4347 @check_to_device_send 4348 def send(self, num_epochs=-1): 4349 """ 4350 Send to device 4351 """ 4352 if Dataset._noop_mode(): 4353 return 4354 if self._to_device is not None: 4355 del self._to_device 4356 self._to_device = _ToDevice(self, num_epochs) 4357 self._to_device.send() 4358 4359 def stop_send(self): 4360 if self._to_device is not None: 4361 self._to_device.stop_send() 4362 4363 def continue_send(self): 4364 if self._to_device is not None: 4365 self._to_device.continue_send() 4366 4367 def get_data_info(self): 4368 """ 4369 Get type and shape of current batch 4370 """ 4371 if self._to_device is not None: 4372 return self._to_device.get_data_info() 4373 raise RuntimeError("Calling get_data_info with bad state.") 4374 4375 def get_mbuf_queue_size(self): 4376 """ 4377 Get element numbers inside mbuf. 4378 """ 4379 if self._to_device is not None: 4380 return self._to_device.get_mbuf_queue_size() 4381 raise RuntimeError("Device queue is not init, call get_mbuf_queue_size failed.") 4382 4383 def get_send_info(self): 4384 """ 4385 In sink mode, it returns the send information of dataset at this moment. 4386 Send information includes number of send batches, time summary of fetching data on host 4387 and time summary of sending data. 4388 """ 4389 if self._to_device is not None: 4390 return self._to_device.get_send_info() 4391 raise RuntimeError("Calling get_send_info with bad state, data queue is not initialized.") 4392 4393 def get_offload_model(self): 4394 if self._to_device is not None: 4395 return self._to_device.get_offload_model(self.column_name) 4396 4397 raise RuntimeError("get_offload_model, _to_device is None") 4398 4399 def release(self): 4400 """ 4401 Manually terminate Device Queue instead of relying on out of scope destruction. 4402 """ 4403 if self._to_device is not None: 4404 self._to_device.release() 4405 4406 def _reset(self, step, dataset_size): 4407 if self._to_device is not None: 4408 logger.info("Reset the dataset pipeline to step: " + str(step) + ", epoch: " + str(step // dataset_size)) 4409 self._to_device._reset(step, dataset_size) # pylint: disable=protected-access 4410 4411 4412class Schema: 4413 """ 4414 Class to represent a schema of a dataset. 4415 4416 Args: 4417 schema_file (str): Path of the schema file. Default: ``None``. 4418 4419 Raises: 4420 RuntimeError: If schema file failed to load. 4421 4422 Examples: 4423 >>> import mindspore.dataset as ds 4424 >>> from mindspore import dtype as mstype 4425 >>> 4426 >>> # Create schema; specify column name, mindspore.dtype and shape of the column 4427 >>> schema = ds.Schema() 4428 >>> schema.add_column(name='col1', de_type=mstype.int64, shape=[2]) 4429 """ 4430 4431 @check_schema 4432 def __init__(self, schema_file=None): 4433 self.schema_file = replace_none(schema_file, "") 4434 self.cpp_schema = cde.SchemaObj(self.schema_file) 4435 4436 @check_add_column 4437 def add_column(self, name, de_type, shape=None): 4438 """ 4439 Add new column to the schema. 4440 4441 Args: 4442 name (str): The new name of the column. 4443 de_type (str): Data type of the column. 4444 shape (list[int], optional): Shape of the column. 4445 Default: ``None``, [-1] which is an unknown shape of rank 1. 4446 4447 Raises: 4448 ValueError: If column type is unknown. 4449 4450 Examples: 4451 >>> import mindspore.dataset as ds 4452 >>> from mindspore import dtype as mstype 4453 >>> 4454 >>> schema = ds.Schema() 4455 >>> schema.add_column('col_1d', de_type=mstype.int64, shape=[2]) 4456 """ 4457 if isinstance(de_type, typing.Type): 4458 de_type = mstype_to_detype(de_type) 4459 col_type = str(de_type) 4460 else: 4461 col_type = str(cde.DataType(de_type)) 4462 if shape is None: 4463 self.cpp_schema.add_column(name, col_type) 4464 else: 4465 self.cpp_schema.add_column(name, col_type, shape) 4466 4467 def parse_columns(self, columns): 4468 """ 4469 Parse the columns and add it to self. 4470 4471 Args: 4472 columns (Union[dict, list[dict], tuple[dict]]): Dataset attribute information, decoded from schema file. 4473 4474 - list[dict], `name` and `type` must be in keys, `shape` optional. 4475 4476 - dict, columns.keys() as name, columns.values() is dict, and `type` inside, `shape` optional. 4477 4478 Raises: 4479 RuntimeError: If failed to parse columns. 4480 RuntimeError: If column's name field is missing. 4481 RuntimeError: If column's type field is missing. 4482 4483 Examples: 4484 >>> from mindspore.dataset import Schema 4485 >>> schema = Schema() 4486 >>> columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]}, 4487 ... {'name': 'label', 'type': 'int8', 'shape': [1]}] 4488 >>> schema.parse_columns(columns1) 4489 >>> columns2 = {'image': {'shape': [3, 3], 'type': 'int8'}, 'label': {'shape': [1], 'type': 'int8'}} 4490 >>> schema.parse_columns(columns2) 4491 """ 4492 self.cpp_schema.parse_columns(json.dumps(columns, indent=2)) 4493 4494 def to_json(self): 4495 """ 4496 Get a JSON string of the schema. 4497 4498 Returns: 4499 str, JSON string of the schema. 4500 4501 Examples: 4502 >>> from mindspore.dataset import Schema 4503 >>> from mindspore import dtype as mstype 4504 >>> 4505 >>> schema = Schema() 4506 >>> schema.add_column('col_1d', de_type=mstype.int64, shape=[2]) 4507 >>> json = schema.to_json() 4508 """ 4509 return self.cpp_schema.to_json() 4510 4511 def from_json(self, json_obj): 4512 """ 4513 Get schema file from JSON object. 4514 4515 Args: 4516 json_obj(dictionary): Object of JSON parsed. 4517 4518 Raises: 4519 RuntimeError: if there is unknown item in the object. 4520 RuntimeError: if dataset type is missing in the object. 4521 RuntimeError: if columns are missing in the object. 4522 4523 Examples: 4524 >>> import json 4525 >>> from mindspore.dataset import Schema 4526 >>> 4527 >>> with open("/path/to/schema_file", "r") as file: 4528 ... json_obj = json.load(file) 4529 ... schema = Schema() 4530 ... schema.from_json(json_obj) 4531 """ 4532 self.cpp_schema.from_string(json.dumps(json_obj, indent=2)) 4533 4534 def __str__(self): 4535 return self.to_json() 4536 4537 @staticmethod 4538 def get_num_rows(schema): 4539 schema_obj = schema 4540 if not isinstance(schema_obj, Schema): 4541 schema_obj = Schema(schema_obj) 4542 return schema_obj.cpp_schema.get_num_rows() 4543 4544 4545class DeserializedDataset(Dataset): 4546 def __init__(self, input_obj): 4547 super().__init__() 4548 self.input_obj = input_obj 4549 4550 def parse(self, children=None): 4551 if isinstance(self.input_obj, dict): 4552 json_str = json.dumps(self.input_obj) 4553 return cde.Dataset.from_json_string(json_str) 4554 return cde.Dataset.from_json_file(self.input_obj) 4555