1# Copyright 2019-2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15""" 16This dataset module supports various formats of datasets, including ImageNet, TFData, 17MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with 18high performance and parses data precisely. Some of the operations that are 19provided to users to preprocess data include shuffle, batch, repeat, map, and zip. 20""" 21import atexit 22import glob 23import json 24import math 25import os 26import signal 27import stat 28import time 29import uuid 30import multiprocessing 31from multiprocessing.pool import RUN 32import queue 33from enum import Enum 34from functools import partial 35from importlib import import_module 36import sys 37import threading 38 39import copy 40import weakref 41import platform 42import psutil 43import numpy as np 44from scipy.io import loadmat 45from PIL import Image 46 47import mindspore._c_dataengine as cde 48from mindspore._c_expression import typing 49 50from mindspore import Tensor 51from mindspore import log as logger 52from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched 53from mindspore.parallel._utils import _get_device_num 54 55import mindspore.dataset.transforms.py_transforms as py_transforms 56 57from . import samplers 58from .iterators import DictIterator, TupleIterator, DummyIterator, check_iterator_cleanup, _set_iterator_cleanup, \ 59 ITERATORS_LIST, _unset_iterator_cleanup 60from .queue import _SharedQueue 61from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ 62 check_rename, check_numpyslicesdataset, check_device_send, check_take, check_project, check_imagefolderdataset, \ 63 check_mnist_cifar_dataset, check_manifestdataset, check_tfrecorddataset, check_vocdataset, check_cocodataset, \ 64 check_celebadataset, check_minddataset, check_generatordataset, check_sync_wait, check_zip_dataset, \ 65 check_add_column, check_textfiledataset, check_concat, check_random_dataset, check_split, \ 66 check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \ 67 check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \ 68 check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \ 69 check_sbu_dataset 70from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \ 71 get_prefetch_size 72from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist 73from ..core.validator_helpers import replace_none 74from ..core.py_util_helpers import ExceptionHandler 75from ..transforms.py_transforms_util import FuncWrapper 76 77try: 78 context = import_module("mindspore.context") 79except ModuleNotFoundError: 80 context = None 81 82 83class Shuffle(str, Enum): 84 GLOBAL: str = "global" 85 FILES: str = "files" 86 INFILE: str = "infile" 87 88 89ShuffleToShuffleMode = {Shuffle.FILES: cde.ShuffleMode.FILES, 90 Shuffle.GLOBAL: cde.ShuffleMode.GLOBAL, 91 Shuffle.INFILE: cde.ShuffleMode.INFILE} 92 93 94def shuffle_to_shuffle_mode(shuffle): 95 """class Shuffle Enum to int""" 96 shuffle_mode = cde.ShuffleMode.GLOBAL # Global shuffle 97 if not isinstance(shuffle, Shuffle): 98 if shuffle is None or shuffle: 99 shuffle_mode = cde.ShuffleMode.GLOBAL # Global shuffle 100 else: 101 shuffle_mode = cde.ShuffleMode.FALSE # No shuffle 102 else: 103 shuffle_mode = ShuffleToShuffleMode[shuffle] 104 return shuffle_mode 105 106 107def shuffle_to_bool(shuffle): 108 """class Shuffle Enum to bool""" 109 shuffle_bool = True 110 if not isinstance(shuffle, Shuffle): 111 if shuffle is None: 112 shuffle_bool = None 113 elif shuffle: 114 shuffle_bool = True 115 else: 116 shuffle_bool = False 117 else: 118 shuffle_bool = True 119 return shuffle_bool 120 121 122@check_zip 123def zip(datasets): 124 """ 125 Zip the datasets in the input tuple of datasets. 126 127 Args: 128 datasets (tuple of class Dataset): A tuple of datasets to be zipped together. 129 The number of datasets must be more than 1. 130 131 Returns: 132 ZipDataset, dataset zipped. 133 134 Raises: 135 ValueError: If the number of datasets is 1. 136 TypeError: If datasets is not a tuple. 137 138 Examples: 139 >>> # Create a dataset which is the combination of dataset_1 and dataset_2 140 >>> dataset = ds.zip((dataset_1, dataset_2)) 141 """ 142 if len(datasets) <= 1: 143 raise ValueError( 144 "Can't zip empty or just one dataset!") 145 for dataset in datasets: 146 if not isinstance(dataset, Dataset): 147 raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset)) 148 return ZipDataset(datasets) 149 150 151def _get_operator_process(): 152 """ 153 Inner implemented method, mainly for passing sub-process id in C layer 154 155 Returns: 156 dict, mapping dict of operator id and corresponding process id. 157 """ 158 global _OP_PROCESS 159 process_info = _OP_PROCESS 160 op_process = dict() 161 keys = process_info.keys() 162 fetched_all = True 163 for key in keys: 164 op_process[key] = list(process_info[key][1]) 165 item_full = (len(process_info[key][1]) == process_info[key][0]) 166 fetched_all = fetched_all and item_full 167 return op_process, fetched_all 168 169 170def _set_dataset_permissions(file_name, num_files): 171 """ 172 set saved dataset files' permissions to 600 173 the rule of dataset filenames should be the same as those in C++. 174 """ 175 num_digits = len(str(num_files - 1)) 176 if num_files == 1: 177 paths = [file_name] 178 else: 179 paths = ["{}{}".format(file_name, str(x).rjust(num_digits, '0')) for x in range(num_files)] 180 181 for item in paths: 182 if os.path.exists(item): 183 os.chmod(item, stat.S_IRUSR | stat.S_IWUSR) 184 index_file = item + ".db" 185 if os.path.exists(index_file): 186 os.chmod(index_file, stat.S_IRUSR | stat.S_IWUSR) 187 188 189class Dataset: 190 """ 191 Abstract class to represent a dataset in DataEngine's data pipeline. 192 193 This class is the base class of SourceDataset and Dataset, and represents 194 a node in the data flow graph. 195 196 Args: 197 num_parallel_workers (int, optional): Number of workers to process the dataset in parallel 198 (default=None). 199 """ 200 201 def __init__(self, children=None, num_parallel_workers=None, cache=None): 202 # Note: children and parent are internal variables, not recommended for external using. 203 self.children = replace_none(children, []) 204 if isinstance(self.children, tuple): 205 self.children = list(self.children) 206 if not isinstance(self.children, list): 207 self.children = [self.children] 208 209 self.parent = [] 210 for child in self.children: 211 child.parent.append(weakref.ref(self)) 212 self.num_parallel_workers = num_parallel_workers 213 self.cache = cache 214 215 self._device_iter = 0 216 self._input_indexs = () 217 self.saved_output_types = None 218 self.saved_output_shapes = None 219 self.dynamic_setting = [False, None] 220 self.saved_min_shapes = None 221 self.saved_max_shapes = None 222 self._col_names = None 223 self.dataset_size = None 224 self._batch_size = None 225 self._num_classes = None 226 self._repeat_count = None 227 self._class_indexing = None 228 self._sync = False 229 230 def create_ir_tree(self): 231 """ 232 Internal method to build an IR tree. 233 234 Returns: 235 DatasetNode, the root node of the IR tree. 236 Dataset, the root dataset of the IR tree. 237 """ 238 parent = self.parent 239 self.parent = [] 240 dataset = copy.deepcopy(self) 241 global _OP_NAME 242 _OP_NAME = Dataset._get_operator_id(dataset) 243 ir_tree = dataset.parse_tree() 244 self.parent = parent 245 _init_device_info() 246 return ir_tree, dataset 247 248 def close_pool(self): 249 """ 250 Close multiprocessing pool in dataset. If you are familiar with multiprocessing library, you can regard this 251 as a destructor for a processingPool object. 252 """ 253 if hasattr(self, 'process_pool') and self.process_pool is not None: 254 self.process_pool.close() 255 for child in self.children: 256 child.close_pool() 257 258 def notify_watchdog(self): 259 if hasattr(self, 'sample_fn') and self.sample_fn is not None: 260 if self.sample_fn.multi_process: 261 self.sample_fn._abort_watchdog() # pylint: disable=W0212 262 if hasattr(self, 'watch_dog') and self.watch_dog is not None and hasattr(self, 'eot') and self.eot is not None: 263 self._abort_watchdog() 264 for child in self.children: 265 child.notify_watchdog() 266 267 @staticmethod 268 def _get_operator_id(dataset): 269 """ 270 Internal method to iterate the tree and obtain op_id of each operator. 271 272 Returns: 273 Dataset, the root dataset of the tree. 274 """ 275 op_name = dict() 276 generator_process = dict() 277 op_name[str(dataset)] = 0 278 op_id = 1 279 280 def process_name(datasets, operator_id): 281 if not datasets: 282 return 0 283 temp = [] 284 for item in datasets: 285 for d in item.children: 286 temp.append(d) 287 op_name[str(d)] = operator_id 288 if isinstance(d, GeneratorDataset) and d.sample_fn and d.sample_fn.pids: 289 generator_process[operator_id] = [d.num_parallel_workers, set(d.sample_fn.pids)] 290 291 operator_id = operator_id + 1 292 return process_name(temp, operator_id) 293 294 process_name([dataset], op_id) 295 if generator_process: 296 global _OP_PROCESS 297 _OP_PROCESS.update(generator_process) 298 return op_name 299 300 def parse_tree(self): 301 """ 302 Internal method to parse the API tree into an IR tree. 303 304 Returns: 305 DatasetNode, the root node of the IR tree. 306 """ 307 if len(self.parent) > 1: 308 raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)") 309 ir_children = [d.parse_tree() for d in self.children] 310 # Bootstrap can only be performed on a copy of the original dataset node. 311 # Bootstrap on original dataset node will make all iterators share the same process pool 312 self.iterator_bootstrap() 313 ir_node = self.parse(ir_children) 314 ir_node = self.post_parse(ir_node) 315 return ir_node 316 317 def __safe_deepcopy__(self, memodict, exclude=()): 318 if id(self) in memodict: 319 return memodict[id(self)] 320 cls = self.__class__ 321 new_op = cls.__new__(cls) 322 memodict[id(self)] = new_op 323 for arg, value in self.__dict__.items(): 324 if arg in exclude: 325 setattr(new_op, arg, value) 326 else: 327 try: 328 setattr(new_op, arg, copy.deepcopy(value, memodict)) 329 except TypeError: 330 setattr(new_op, arg, value) 331 return new_op 332 333 def iterator_bootstrap(self): 334 pass 335 336 @staticmethod 337 def _noop_mode(): 338 if _is_role_sched() or _is_role_pserver(): 339 return True 340 return False 341 342 def __add__(self, datasets): 343 return self.concat(datasets) 344 345 def to_json(self, filename=""): 346 """ 347 Serialize a pipeline into JSON string and dump into file if filename is provided. 348 349 Args: 350 filename (str): filename of JSON file to be saved as. 351 352 Returns: 353 str, JSON string of the pipeline. 354 """ 355 ir_tree, _ = self.create_ir_tree() 356 return json.loads(ir_tree.to_json(filename)) 357 358 @check_bucket_batch_by_length 359 def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes, element_length_function=None, 360 pad_info=None, pad_to_bucket_boundary=False, drop_remainder=False): 361 """ 362 Bucket elements according to their lengths. Each bucket will be padded and batched when 363 they are full. 364 365 A length function is called on each row in the dataset. The row is then 366 bucketed based on its length and bucket boundaries. When a bucket reaches its 367 corresponding size specified in bucket_batch_sizes, the entire bucket will be 368 padded according to batch_info, and then form a batch. 369 Each batch will be full, except one special case: the last batch for each bucket may not be full. 370 371 Args: 372 column_names (list[str]): Columns passed to element_length_function. 373 bucket_boundaries (list[int]): A list consisting of the upper boundaries 374 of the buckets. Must be strictly increasing. If there are n boundaries, 375 n+1 buckets are created: One bucket for [0, bucket_boundaries[0]), one 376 bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each 377 0<i<n-1, and last bucket for [bucket_boundaries[n-1], inf). 378 bucket_batch_sizes (list[int]): A list consisting of the batch sizes for 379 each bucket. Must contain len(bucket_boundaries)+1 elements. 380 element_length_function (Callable, optional): A function that takes in 381 M arguments where M = len(column_names) and returns an integer. If no value 382 provided, parameter M the len(column_names) must be 1, and the size of the first 383 dimension of that column will be taken as the length (default=None). 384 pad_info (dict, optional): The information about how to batch each column. The key 385 corresponds to the column name, and the value must be a tuple of 2 elements. 386 The first element corresponds to the shape to pad to, and the second 387 element corresponds to the value to pad with. If a column is not 388 specified, then that column will be padded to the longest in the current 389 batch, and 0 will be used as the padding value. Any None dimensions will 390 be padded to the longest in the current batch, unless if 391 pad_to_bucket_boundary is True. If no padding is wanted, set pad_info 392 to None (default=None). 393 pad_to_bucket_boundary (bool, optional): If True, will pad each None 394 dimension in pad_info to the bucket_boundary minus 1. If there are any 395 elements that fall into the last bucket, an error will occur 396 (default=False). 397 drop_remainder (bool, optional): If True, will drop the last batch for each 398 bucket if it is not a full batch (default=False). 399 400 Returns: 401 BucketBatchByLengthDataset, dataset bucketed and batched by length. 402 403 Examples: 404 >>> # Create a dataset where certain counts rows are combined into a batch 405 >>> # and drops the last incomplete batch if there is one. 406 >>> import numpy as np 407 >>> def generate_2_columns(n): 408 ... for i in range(n): 409 ... yield (np.array([i]), np.array([j for j in range(i + 1)])) 410 >>> 411 >>> column_names = ["col1", "col2"] 412 >>> dataset = ds.GeneratorDataset(generate_2_columns(8), column_names) 413 >>> bucket_boundaries = [5, 10] 414 >>> bucket_batch_sizes = [2, 1, 1] 415 >>> element_length_function = (lambda col1, col2: max(len(col1), len(col2))) 416 >>> # Will pad col2 to shape [bucket_boundaries[i]] where i is the 417 >>> # index of the bucket that is currently being batched. 418 >>> pad_info = {"col2": ([None], -1)} 419 >>> pad_to_bucket_boundary = True 420 >>> dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, 421 ... bucket_batch_sizes, 422 ... element_length_function, pad_info, 423 ... pad_to_bucket_boundary) 424 """ 425 return BucketBatchByLengthDataset(self, column_names, bucket_boundaries, bucket_batch_sizes, 426 element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder) 427 428 @check_batch 429 def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, 430 input_columns=None, output_columns=None, column_order=None, pad_info=None, python_multiprocessing=False, 431 max_rowsize=16): 432 """ 433 Combine batch_size number of consecutive rows into batches. 434 435 For any child node, a batch is treated as a single row. 436 For any column, all the elements within that column must have the same shape. 437 If a per_batch_map callable is provided, it will be applied to the batches of tensors. 438 439 Note: 440 The order of using repeat and batch reflects the number of batches and per_batch_map. 441 It is recommended that the repeat operation applied after the batch operation finished. 442 443 Args: 444 batch_size (int or function): The number of rows each batch is created with. An 445 int or callable object which takes exactly 1 parameter, BatchInfo. 446 drop_remainder (bool, optional): Determines whether or not to drop the last block 447 whose data row number is less than batch size (default=False). If True, and if there are less 448 than batch_size rows available to make the last batch, then those rows will 449 be dropped and not propagated to the child node. 450 num_parallel_workers (int, optional): Number of workers(threads) to process the dataset in parallel 451 (default=None). 452 per_batch_map (callable, optional): Per batch map callable. A callable which takes 453 (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch 454 of Tensors on a given column. The number of lists should match with number of entries in input_columns. 455 The last parameter of the callable should always be a BatchInfo object. Per_batch_map should return 456 (list[Tensor], list[Tensor], ...). The length of each list in output should be same as the input. 457 output_columns is required if the number of output lists is different from input. 458 input_columns (Union[str, list[str]], optional): List of names of the input columns. The size of the list 459 should match with signature of per_batch_map callable (default=None). 460 output_columns (Union[str, list[str]], optional): List of names assigned to the columns 461 outputted by the last operation. This parameter is mandatory if len(input_columns) != 462 len(output_columns). The size of this list must match the number of output 463 columns of the last operation. (default=None, output columns will have the same 464 name as the input columns, i.e., the columns will be replaced). 465 column_order (Union[str, list[str]], optional): Specifies the list of all the columns you need in the whole 466 dataset. The parameter is required when len(input_column) != len(output_column). Caution: the list here 467 is not just the columns specified in parameter input_columns and output_columns. 468 pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)} 469 would pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0 470 (default=None). 471 python_multiprocessing (bool, optional): Parallelize Python function per_batch_map with multi-processing. 472 This option could be beneficial if the function is computational heavy (default=False). 473 max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy 474 data between processes. This is only used if python_multiprocessing is set to True (default 16 MB). 475 476 Returns: 477 BatchDataset, dataset batched. 478 479 Examples: 480 >>> # Create a dataset where every 100 rows are combined into a batch 481 >>> # and drops the last incomplete batch if there is one. 482 >>> dataset = dataset.batch(100, True) 483 >>> # resize image according to its batch number, if it's 5-th batch, resize to (5^2, 5^2) = (25, 25) 484 >>> def np_resize(col, batchInfo): 485 ... output = col.copy() 486 ... s = (batchInfo.get_batch_num() + 1) ** 2 487 ... index = 0 488 ... for c in col: 489 ... img = Image.fromarray(c.astype('uint8')).convert('RGB') 490 ... img = img.resize((s, s), Image.ANTIALIAS) 491 ... output[index] = np.array(img) 492 ... index += 1 493 ... return (output,) 494 >>> dataset = dataset.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize) 495 """ 496 return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns, 497 output_columns, column_order, pad_info, python_multiprocessing, max_rowsize) 498 499 @check_sync_wait 500 def sync_wait(self, condition_name, num_batch=1, callback=None): 501 """ 502 Add a blocking condition to the input Dataset. A synchronize action will be applied. 503 504 Args: 505 condition_name (str): The condition name that is used to toggle sending next row. 506 num_batch (int): the number of batches without blocking at the start of each epoch. 507 callback (function): The callback function that will be invoked when sync_update is called. 508 509 Returns: 510 SyncWaitDataset, dataset added a blocking condition. 511 512 Raises: 513 RuntimeError: If condition name already exists. 514 515 Examples: 516 >>> import numpy as np 517 >>> def gen(): 518 ... for i in range(100): 519 ... yield (np.array(i),) 520 >>> 521 >>> class Augment: 522 ... def __init__(self, loss): 523 ... self.loss = loss 524 ... 525 ... def preprocess(self, input_): 526 ... return input_ 527 ... 528 ... def update(self, data): 529 ... self.loss = data["loss"] 530 >>> 531 >>> batch_size = 4 532 >>> dataset = ds.GeneratorDataset(gen, column_names=["input"]) 533 >>> 534 >>> aug = Augment(0) 535 >>> dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) 536 >>> dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"]) 537 >>> dataset = dataset.batch(batch_size) 538 >>> count = 0 539 >>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): 540 ... assert data["input"][0] == count 541 ... count += batch_size 542 ... data = {"loss": count} 543 ... dataset.sync_update(condition_name="policy", data=data) 544 """ 545 return SyncWaitDataset(self, condition_name, num_batch, callback) 546 547 @check_shuffle 548 def shuffle(self, buffer_size): 549 """ 550 Randomly shuffles the rows of this dataset using the following policy: 551 552 1. Make a shuffle buffer that contains the first buffer_size rows. 553 2. Randomly select an element from the shuffle buffer to be the next row 554 propagated to the child node. 555 3. Get the next row (if any) from the parent node and put it in the shuffle buffer. 556 4. Repeat steps 2 and 3 until there are no more rows left in the shuffle buffer. 557 558 A random seed can be provided to be used on the first epoch. In every subsequent 559 epoch, the seed is changed to a new one, randomly generated value. 560 561 Args: 562 buffer_size (int): The size of the buffer (must be larger than 1) for 563 shuffling. Setting buffer_size equal to the number of rows in the entire 564 dataset will result in a global shuffle. 565 566 Returns: 567 ShuffleDataset, dataset shuffled. 568 569 Raises: 570 RuntimeError: If exist sync operators before shuffle. 571 572 Examples: 573 >>> # dataset is an instance object of Dataset 574 >>> # Optionally set the seed for the first epoch 575 >>> ds.config.set_seed(58) 576 >>> # Create a shuffled dataset using a shuffle buffer of size 4 577 >>> dataset = dataset.shuffle(4) 578 """ 579 return ShuffleDataset(self, buffer_size) 580 581 def flat_map(self, func): 582 """ 583 Map `func` to each row in dataset and flatten the result. 584 585 The specified `func` is a function that must take one 'Ndarray' as input 586 and return a 'Dataset'. 587 588 Args: 589 func (function): A function that must take one 'Ndarray' as an argument and 590 return a 'Dataset'. 591 592 Returns: 593 Dataset, dataset applied by the function. 594 595 Examples: 596 >>> # use NumpySlicesDataset as an example 597 >>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]]) 598 >>> 599 >>> def flat_map_func(array): 600 ... # create a NumpySlicesDataset with the array 601 ... dataset = ds.NumpySlicesDataset(array) 602 ... # repeat the dataset twice 603 ... dataset = dataset.repeat(2) 604 ... return dataset 605 >>> 606 >>> dataset = dataset.flat_map(flat_map_func) 607 >>> # [[0, 1], [0, 1], [2, 3], [2, 3]] 608 609 Raises: 610 TypeError: If `func` is not a function. 611 TypeError: If `func` doesn't return a Dataset. 612 """ 613 dataset = None 614 if not hasattr(func, '__call__'): 615 logger.error("func must be a function.") 616 raise TypeError("func must be a function.") 617 618 for row_data in self.create_tuple_iterator(output_numpy=True): 619 if dataset is None: 620 dataset = func(row_data) 621 else: 622 dataset += func(row_data) 623 624 if not isinstance(dataset, Dataset): 625 logger.error("flat_map must return a Dataset object.") 626 raise TypeError("flat_map must return a Dataset object.") 627 return dataset 628 629 @check_map 630 def map(self, operations, input_columns=None, output_columns=None, column_order=None, 631 num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16): 632 """ 633 Apply each operation in operations to this dataset. 634 635 The order of operations is determined by the position of each operation in the operations parameter. 636 operations[0] will be applied first, then operations[1], then operations[2], etc. 637 638 Each operation will be passed one or more columns from the dataset as input, and zero or 639 more columns will be outputted. The first operation will be passed the columns specified 640 in input_columns as input. If there is more than one operator in operations, the outputted 641 columns of the previous operation are used as the input columns for the next operation. 642 The columns outputted by the very last operation will be assigned names specified by 643 output_columns. 644 645 Only the columns specified in column_order will be propagated to the child node. These 646 columns will be in the same order as specified in column_order. 647 648 Args: 649 operations (Union[list[TensorOp], list[functions]]): List of operations to be 650 applied on the dataset. Operations are applied in the order they appear in this list. 651 input_columns (Union[str, list[str]], optional): List of the names of the columns that will be passed to 652 the first operation as input. The size of this list must match the number of 653 input columns expected by the first operator. (default=None, the first 654 operation will be passed however many columns that are required, starting from 655 the first column). 656 output_columns (Union[str, list[str]], optional): List of names assigned to the columns outputted by 657 the last operation. This parameter is mandatory if len(input_columns) != 658 len(output_columns). The size of this list must match the number of output 659 columns of the last operation. (default=None, output columns will have the same 660 name as the input columns, i.e., the columns will be replaced). 661 column_order (list[str], optional): Specifies the list of all the columns you need in the whole 662 dataset. The parameter is required when len(input_column) != len(output_column). Caution: the list here 663 is not just the columns specified in parameter input_columns and output_columns. 664 num_parallel_workers (int, optional): Number of threads used to process the dataset in 665 parallel (default=None, the value from the configuration will be used). 666 python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This 667 option could be beneficial if the Python operation is computational heavy (default=False). 668 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 669 (default=None, which means no cache is used). 670 callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None). 671 max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy 672 data between processes. This is only used if python_multiprocessing is set to True (default 16 MB). 673 674 675 Returns: 676 MapDataset, dataset after mapping operation. 677 678 Examples: 679 >>> # dataset is an instance of Dataset which has 2 columns, "image" and "label". 680 >>> 681 >>> # Define two operations, where each operation accepts 1 input column and outputs 1 column. 682 >>> decode_op = c_vision.Decode(rgb=True) 683 >>> random_jitter_op = c_vision.RandomColorAdjust(brightness=(0.8, 0.8), contrast=(1, 1), 684 ... saturation=(1, 1), hue=(0, 0)) 685 >>> 686 >>> # 1) Simple map example. 687 >>> 688 >>> # Apply decode_op on column "image". This column will be replaced by the outputted 689 >>> # column of decode_op. Since column_order is not provided, both columns "image" 690 >>> # and "label" will be propagated to the child node in their original order. 691 >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"]) 692 >>> 693 >>> # Decode and rename column "image" to "decoded_image". 694 >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"], output_columns=["decoded_image"]) 695 >>> 696 >>> # Specify the order of the output columns. 697 >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"], 698 ... output_columns=None, column_order=["label", "image"]) 699 >>> 700 >>> # Rename column "image" to "decoded_image" and also specify the order of the output columns. 701 >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"], 702 ... output_columns=["decoded_image"], column_order=["label", "decoded_image"]) 703 >>> 704 >>> # Rename column "image" to "decoded_image" and keep only this column. 705 >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"], 706 ... output_columns=["decoded_image"], column_order=["decoded_image"]) 707 >>> 708 >>> # A simple example for mapping pyfunc. Renaming columns and specifying column order 709 >>> # work in the same way as the previous examples. 710 >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"]) 711 >>> dataset = dataset.map(operations=[(lambda x: x + 1)], input_columns=["data"]) 712 >>> 713 >>> # 2) Map example with more than one operation. 714 >>> 715 >>> # Create a dataset where the images are decoded, then randomly color jittered. 716 >>> # decode_op takes column "image" as input and outputs one column. The column 717 >>> # outputted by decode_op is passed as input to random_jitter_op. 718 >>> # random_jitter_op will output one column. Column "image" will be replaced by 719 >>> # the column outputted by random_jitter_op (the very last operation). All other 720 >>> # columns are unchanged. Since column_order is not specified, the order of the 721 >>> # columns will remain the same. 722 >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"]) 723 >>> 724 >>> # Rename the column outputted by random_jitter_op to "image_mapped". 725 >>> # Specifying column order works in the same way as examples in 1). 726 >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"], 727 ... output_columns=["image_mapped"]) 728 >>> 729 >>> # Map with multiple operations using pyfunc. Renaming columns and specifying column order 730 >>> # work in the same way as examples in 1). 731 >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"]) 732 >>> dataset = dataset.map(operations=[(lambda x: x * x), (lambda x: x - 1)], input_columns=["data"], 733 ... output_columns=["data_mapped"]) 734 >>> 735 >>> # 3) Example where number of input columns is not equal to number of output columns. 736 >>> 737 >>> # operations[0] is a lambda that takes 2 columns as input and outputs 3 columns. 738 >>> # operations[1] is a lambda that takes 3 columns as input and outputs 1 column. 739 >>> # operations[2] is a lambda that takes 1 column as input and outputs 4 columns. 740 >>> # 741 >>> # Note: The number of output columns of operation[i] must equal the number of 742 >>> # input columns of operation[i+1]. Otherwise, this map call will also result 743 >>> # in an error. 744 >>> operations = [(lambda x, y: (x, x + y, x + y + 1)), 745 ... (lambda x, y, z: x * y * z), 746 ... (lambda x: (x % 2, x % 3, x % 5, x % 7))] 747 >>> 748 >>> # Note: Since the number of input columns is not the same as the number of 749 >>> # output columns, the output_columns and column_order parameters must be 750 >>> # specified. Otherwise, this map call will also result in an error. 751 >>> 752 >>> dataset = ds.NumpySlicesDataset(data=([[0, 1, 2]], [[3, 4, 5]]), column_names=["x", "y"]) 753 >>> 754 >>> # Propagate all columns to the child node in this order: 755 >>> dataset = dataset.map(operations, input_columns=["x", "y"], 756 ... output_columns=["mod2", "mod3", "mod5", "mod7"], 757 ... column_order=["mod2", "mod3", "mod5", "mod7"]) 758 >>> 759 >>> # Propagate some columns to the child node in this order: 760 >>> dataset = dataset.map(operations, input_columns=["x", "y"], 761 ... output_columns=["mod2", "mod3", "mod5", "mod7"], 762 ... column_order=["mod7", "mod3", "col2"]) 763 """ 764 765 return MapDataset(self, operations, input_columns, output_columns, column_order, num_parallel_workers, 766 python_multiprocessing, cache, callbacks, max_rowsize) 767 768 @check_filter 769 def filter(self, predicate, input_columns=None, num_parallel_workers=None): 770 """ 771 Filter dataset by prediction. 772 773 Note: 774 If input_columns not provided or provided with empty, all columns will be used. 775 776 Args: 777 predicate (callable): Python callable which returns a boolean value. If False then filter the element. 778 input_columns (Union[str, list[str]], optional): List of names of the input columns, when 779 default=None, the predicate will be applied on all columns in the dataset. 780 num_parallel_workers (int, optional): Number of workers to process the dataset 781 in parallel (default=None). 782 783 Returns: 784 FilterDataset, dataset filtered. 785 786 Examples: 787 >>> # generator data(0 ~ 63) 788 >>> # filter the data that greater than or equal to 11 789 >>> dataset = dataset.filter(predicate=lambda data: data < 11, input_columns = ["data"]) 790 """ 791 return FilterDataset(self, predicate, input_columns, num_parallel_workers) 792 793 @check_repeat 794 def repeat(self, count=None): 795 """ 796 Repeat this dataset `count` times. Repeat infinitely if the count is None or -1. 797 798 Note: 799 The order of using repeat and batch reflects the number of batches. It is recommended that 800 the repeat operation is used after the batch operation. 801 802 Args: 803 count (int): Number of times the dataset is going to be repeated (default=None). 804 805 Returns: 806 RepeatDataset, dataset repeated. 807 808 Examples: 809 >>> # dataset is an instance object of Dataset 810 >>> 811 >>> # Create a dataset where the dataset is repeated for 50 epochs 812 >>> dataset = dataset.repeat(50) 813 >>> 814 >>> # Create a dataset where each epoch is shuffled individually 815 >>> dataset = dataset.shuffle(10) 816 >>> dataset = dataset.repeat(50) 817 >>> 818 >>> # Create a dataset where the dataset is first repeated for 819 >>> # 50 epochs before shuffling. The shuffle operator will treat 820 >>> # the entire 50 epochs as one big dataset. 821 >>> dataset = dataset.repeat(50) 822 >>> dataset = dataset.shuffle(10) 823 """ 824 return RepeatDataset(self, count) 825 826 @check_skip 827 def skip(self, count): 828 """ 829 Skip the first N elements of this dataset. 830 831 Args: 832 count (int): Number of elements in the dataset to be skipped. 833 834 Returns: 835 SkipDataset, dataset that containing rows like origin rows subtract skipped rows. 836 837 Examples: 838 >>> # dataset is an instance object of Dataset 839 >>> # Create a dataset which skips first 3 elements from data 840 >>> dataset = dataset.skip(3) 841 """ 842 return SkipDataset(self, count) 843 844 @check_take 845 def take(self, count=-1): 846 """ 847 Takes at most given numbers of elements from the dataset. 848 849 Note: 850 1. If count is greater than the number of elements in the dataset or equal to -1, 851 all the elements in dataset will be taken. 852 2. The order of using take and batch matters. If take is before batch operation, 853 then take given number of rows; otherwise take given number of batches. 854 855 Args: 856 count (int, optional): Number of elements to be taken from the dataset (default=-1). 857 858 Returns: 859 TakeDataset, dataset taken. 860 861 Examples: 862 >>> # dataset is an instance object of Dataset 863 >>> # Create a dataset where the dataset includes 50 elements. 864 >>> dataset = dataset.take(50) 865 """ 866 return TakeDataset(self, count) 867 868 def _get_absolute_split_sizes(self, sizes): 869 """ 870 Internal method called by split to calculate absolute split sizes and to 871 do some error checking after calculating absolute split sizes. 872 873 Returns: 874 int, absolute split sizes of the dataset. 875 """ 876 # Call get_dataset_size here and check input here because 877 # don't want to call this once in check_split and another time in 878 # here again 879 dataset_size = self.get_dataset_size() 880 881 if dataset_size is None or dataset_size <= 0: 882 raise RuntimeError("dataset_size is unknown, unable to split.") 883 884 if not isinstance(sizes, list): 885 raise RuntimeError("sizes must be a list.") 886 887 all_int = all(isinstance(item, int) for item in sizes) 888 if all_int: 889 sizes_sum = sum(sizes) 890 if sizes_sum != dataset_size: 891 raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}." 892 .format(sizes_sum, dataset_size)) 893 return sizes 894 895 absolute_sizes = [] 896 for item in sizes: 897 absolute_size = int(round(item * dataset_size)) 898 if absolute_size == 0: 899 raise RuntimeError("Split percentage {} is too small.".format(item)) 900 absolute_sizes.append(absolute_size) 901 902 absolute_sizes_sum = sum(absolute_sizes) 903 904 # if we still need more rows, give them to the first split. 905 # if we have too many rows, remove the extras from the first split that has 906 # enough rows. 907 size_difference = int(dataset_size - absolute_sizes_sum) 908 if size_difference > 0: 909 absolute_sizes[0] += size_difference 910 else: 911 for i, _ in enumerate(absolute_sizes): 912 if absolute_sizes[i] + size_difference > 0: 913 absolute_sizes[i] += size_difference 914 break 915 916 if sum(absolute_sizes) != dataset_size: 917 raise RuntimeError("Sum of calculated split sizes {} is not equal to dataset size {}." 918 .format(absolute_sizes_sum, dataset_size)) 919 920 return absolute_sizes 921 922 @check_split 923 def split(self, sizes, randomize=True): 924 """ 925 Split the dataset into smaller, non-overlapping datasets. 926 927 This is a general purpose split function which can be called from any operator in the pipeline. 928 There is another, optimized split function, which will be called automatically if ds.split is 929 called where ds is a MappableDataset. 930 931 Args: 932 sizes (Union[list[int], list[float]]): If a list of integers [s1, s2, …, sn] is 933 provided, the dataset will be split into n datasets of size s1, size s2, …, size sn 934 respectively. If the sum of all input sizes does not equal the original dataset size, an 935 error will throw. 936 If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1 937 and must sum to 1, otherwise an error will throw. The dataset will be split into n 938 Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the 939 original dataset. 940 If after rounding: 941 942 - Any size equals 0, an error will occur. 943 - The sum of split sizes < K, the difference of K - sigma(round(fi * k)) will be added to the first 944 split. 945 - The sum of split sizes > K, the difference of sigma(round(fi * K)) - K will be removed from the first 946 large enough split such that it will have at least 1 row after removing the difference. 947 948 randomize (bool, optional): Determines whether or not to split the data randomly (default=True). 949 If True, the data will be randomly split. Otherwise, each split will be created with 950 consecutive rows from the dataset. 951 952 Note: 953 1. Dataset cannot be sharded if split is going to be called. 954 2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead. 955 Shuffling the dataset may not be deterministic, which means the data in each split 956 will be different in each epoch. 957 958 Raises: 959 RuntimeError: If get_dataset_size returns None or is not supported for this dataset. 960 RuntimeError: If `sizes` is list of integers and sum of all elements in sizes does not 961 equal the dataset size. 962 RuntimeError: If `sizes` is list of float and there is a split with size 0 after calculations. 963 RuntimeError: If the dataset is sharded prior to calling split. 964 ValueError: If `sizes` is list of float and not all floats are between 0 and 1, or if the 965 floats don't sum to 1. 966 967 Returns: 968 tuple(Dataset), a tuple of datasets that have been split. 969 970 Examples: 971 >>> # TextFileDataset is not a mappable dataset, so this non-optimized split will be called. 972 >>> # Since many datasets have shuffle on by default, set shuffle to False if split will be called! 973 >>> dataset = ds.TextFileDataset(text_file_dataset_dir, shuffle=False) 974 >>> train_dataset, test_dataset = dataset.split([0.9, 0.1]) 975 """ 976 if self.is_shuffled(): 977 logger.warning("Dataset is shuffled before split.") 978 979 if self.is_sharded(): 980 raise RuntimeError("Dataset should not be sharded before split.") 981 982 absolute_sizes = self._get_absolute_split_sizes(sizes) 983 splits = [] 984 rows_to_skip = 0 985 for size in absolute_sizes: 986 ds = copy.deepcopy(self) 987 if randomize: 988 # want to shuffle the same way every epoch before split 989 # in alter_tree, shuffle buffer is minimum 10000, so use 10000 here 990 ds = ds.shuffle(10000) 991 ds.reshuffle_each_epoch = False 992 993 if rows_to_skip > 0: 994 ds = ds.skip(rows_to_skip) 995 996 ds = ds.take(size) 997 splits.append(ds) 998 999 rows_to_skip += size 1000 1001 return tuple(splits) 1002 1003 @check_zip_dataset 1004 def zip(self, datasets): 1005 """ 1006 Zip the datasets in the sense of input tuple of datasets. Columns in the input datasets must have different 1007 name. 1008 1009 Args: 1010 datasets (Union[tuple, class Dataset]): A tuple of datasets or a single class Dataset 1011 to be zipped together with this dataset. 1012 1013 Returns: 1014 ZipDataset, dataset zipped. 1015 1016 Examples: 1017 >>> # Create a dataset which is the combination of dataset and dataset_1 1018 >>> dataset = dataset.zip(dataset_1) 1019 """ 1020 if isinstance(datasets, tuple): 1021 datasets = (self, *datasets) 1022 elif isinstance(datasets, Dataset): 1023 datasets = (self, datasets) 1024 else: 1025 raise TypeError("Invalid datasets, expected Dataset object or tuple of Dataset, but got %s!" % datasets) 1026 return ZipDataset(datasets) 1027 1028 @check_concat 1029 def concat(self, datasets): 1030 """ 1031 Concatenate the dataset objects in the input list. 1032 Performing "+" operation on dataset objects can achieve the same effect. 1033 1034 Note: 1035 The column name, and rank and type of the column data must be the same in the input datasets. 1036 1037 Args: 1038 datasets (Union[list, class Dataset]): A list of datasets or a single class Dataset 1039 to be concatenated together with this dataset. 1040 1041 Returns: 1042 ConcatDataset, dataset concatenated. 1043 1044 Examples: 1045 >>> # Create a dataset by concatenating dataset_1 and dataset_2 with "+" operator 1046 >>> dataset = dataset_1 + dataset_2 1047 >>> # Create a dataset by concatenating dataset_1 and dataset_2 with concat operation 1048 >>> dataset = dataset_1.concat(dataset_2) 1049 """ 1050 if isinstance(datasets, Dataset): 1051 datasets = [self] + [datasets] 1052 elif isinstance(datasets, list): 1053 datasets = [self] + datasets 1054 else: 1055 raise TypeError("Invalid datasets, expected Dataset object or list of Dataset, but got %s!" % datasets) 1056 return ConcatDataset(datasets) 1057 1058 @check_rename 1059 def rename(self, input_columns, output_columns): 1060 """ 1061 Rename the columns in input datasets. 1062 1063 Args: 1064 input_columns (Union[str, list[str]]): List of names of the input columns. 1065 output_columns (Union[str, list[str]]): List of names of the output columns. 1066 1067 Returns: 1068 RenameDataset, dataset renamed. 1069 1070 Examples: 1071 >>> # dataset is an instance object of Dataset 1072 >>> input_columns = ["input_col1", "input_col2", "input_col3"] 1073 >>> output_columns = ["output_col1", "output_col2", "output_col3"] 1074 >>> 1075 >>> # Create a dataset where input_col1 is renamed to output_col1, and 1076 >>> # input_col2 is renamed to output_col2, and input_col3 is renamed 1077 >>> # to output_col3. 1078 >>> dataset = dataset.rename(input_columns=input_columns, output_columns=output_columns) 1079 """ 1080 1081 return RenameDataset(self, input_columns, output_columns) 1082 1083 @check_project 1084 def project(self, columns): 1085 """ 1086 Project certain columns in input dataset. 1087 1088 The specified columns will be selected from the dataset and passed into 1089 the pipeline with the order specified. The other columns are discarded. 1090 1091 Args: 1092 columns(Union[str, list[str]]): List of names of the columns to project. 1093 1094 Returns: 1095 ProjectDataset, dataset projected. 1096 1097 Examples: 1098 >>> # dataset is an instance object of Dataset 1099 >>> columns_to_project = ["column3", "column1", "column2"] 1100 >>> 1101 >>> # Create a dataset that consists of column3, column1, column2 1102 >>> # in that order, regardless of the original order of columns. 1103 >>> dataset = dataset.project(columns=columns_to_project) 1104 """ 1105 1106 return ProjectDataset(self, columns) 1107 1108 def build_vocab(self, columns, freq_range, top_k, special_tokens, special_first): 1109 """ 1110 Function to create a Vocab from source dataset 1111 1112 Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab 1113 which contains top_k most frequent words (if top_k is specified) 1114 1115 Args: 1116 1117 columns(Union[str, list[str]]): Column names to get words from. 1118 freq_range(tuple[int]): A tuple of integers (min_frequency, max_frequency). Words within the frequency 1119 range will be stored. 1120 Naturally 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency 1121 can be set to default, which corresponds to 0/total_words separately. 1122 top_k(int): Number of words to be built into vocab. top_k most frequent words are 1123 taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken 1124 special_tokens(list[str]): A list of strings, each one is a special token. 1125 special_first(bool): Whether special_tokens will be prepended/appended to vocab, If special_tokens 1126 is specified and special_first is set to default, special_tokens will be prepended. 1127 1128 Returns: 1129 Vocab, vocab built from the dataset. 1130 1131 Examples: 1132 >>> import numpy as np 1133 >>> 1134 >>> def gen_corpus(): 1135 ... # key: word, value: number of occurrences, reason for using letters is so their order is apparent 1136 ... corpus = {"Z": 4, "Y": 4, "X": 4, "W": 3, "U": 3, "V": 2, "T": 1} 1137 ... for k, v in corpus.items(): 1138 ... yield (np.array([k] * v, dtype='S'),) 1139 >>> column_names = ["column1"] 1140 >>> dataset = ds.GeneratorDataset(gen_corpus, column_names) 1141 >>> dataset = dataset.build_vocab(columns=["column1"], 1142 ... freq_range=(1, 10), top_k=5, 1143 ... special_tokens=["<pad>", "<unk>"], 1144 ... special_first=True) 1145 1146 """ 1147 vocab = cde.Vocab() 1148 columns = replace_none(columns, []) 1149 if not isinstance(columns, list): 1150 columns = [columns] 1151 1152 freq_range = replace_none(freq_range, (0, 9223372036854775807)) 1153 if freq_range[0] is None: 1154 freq_range = (0, freq_range[1]) 1155 if freq_range[1] is None: 1156 freq_range = (freq_range[0], 9223372036854775807) 1157 special_tokens = replace_none(special_tokens, []) 1158 top_k = replace_none(top_k, 9223372036854775807) 1159 1160 ir_tree, api_tree = self.create_ir_tree() 1161 1162 # vocab node 1163 vocab_node = cde.BuildVocabNode(ir_tree, vocab, columns, freq_range, top_k, special_tokens, special_first) 1164 1165 runtime_context = cde.PythonRuntimeContext() 1166 runtime_context.Init() 1167 1168 # build vocab 1169 consumer = cde.PythonBuildVocabConsumer() 1170 consumer.Init(vocab_node) 1171 runtime_context.AssignConsumer(consumer) 1172 1173 consumer.Start() 1174 del api_tree 1175 1176 return vocab 1177 1178 def build_sentencepiece_vocab(self, columns, vocab_size, character_coverage, model_type, params): 1179 """ 1180 Function to create a SentencePieceVocab from source dataset 1181 1182 Build a SentencePieceVocab from a dataset. 1183 1184 Args: 1185 1186 columns(list[str]): Column names to get words from. 1187 vocab_size(int): Vocabulary size. 1188 character_coverage(int): Percentage of characters covered by the model, must be between 1189 0.98 and 1.0 Good defaults are: 0.9995 for languages with rich character sets like 1190 Japanese or Chinese character sets, and 1.0 for other languages with small character sets 1191 like English or Latin. 1192 model_type(SentencePieceModel): Model type. Choose from unigram (default), bpe, char, or word. 1193 The input sentence must be pretokenized when using word type. 1194 params(dict): Any extra optional parameters of sentencepiece library according to your raw data 1195 1196 Returns: 1197 SentencePieceVocab, vocab built from the dataset. 1198 1199 Examples: 1200 >>> from mindspore.dataset.text import SentencePieceModel 1201 >>> 1202 >>> def gen_corpus(): 1203 ... # key: word, value: number of occurrences, reason for using letters is so their order is apparent 1204 ... corpus = {"Z": 4, "Y": 4, "X": 4, "W": 3, "U": 3, "V": 2, "T": 1} 1205 ... for k, v in corpus.items(): 1206 ... yield (np.array([k] * v, dtype='S'),) 1207 >>> column_names = ["column1","column2","column3"] 1208 >>> dataset = ds.GeneratorDataset(gen_corpus, column_names) 1209 >>> dataset = dataset.build_sentencepiece_vocab(columns=["column3", "column1", "column2"], 1210 ... vocab_size=5000, 1211 ... character_coverage=0.9995, 1212 ... model_type=SentencePieceModel.UNIGRAM, 1213 ... params={}) 1214 """ 1215 vocab = cde.SentencePieceVocab() 1216 1217 ir_tree, api_tree = self.create_ir_tree() 1218 1219 # vocab node 1220 vocab_node = cde.BuildSentenceVocabNode(ir_tree, vocab, columns, vocab_size, character_coverage, model_type, 1221 params) 1222 1223 runtime_context = cde.PythonRuntimeContext() 1224 runtime_context.Init() 1225 1226 # build vocab 1227 consumer = cde.PythonBuildVocabConsumer() 1228 consumer.Init(vocab_node) 1229 runtime_context.AssignConsumer(consumer) 1230 1231 consumer.Start() 1232 del api_tree 1233 1234 return vocab 1235 1236 def apply(self, apply_func): 1237 """ 1238 Apply a function in this dataset. 1239 1240 Args: 1241 apply_func (function): A function that must take one 'Dataset' as an argument and 1242 return a preprocessed 'Dataset'. 1243 1244 Returns: 1245 Dataset, dataset applied by the function. 1246 1247 Examples: 1248 >>> # dataset is an instance object of Dataset 1249 >>> 1250 >>> # Declare an apply_func function which returns a Dataset object 1251 >>> def apply_func(data): 1252 ... data = data.batch(2) 1253 ... return data 1254 >>> 1255 >>> # Use apply to call apply_func 1256 >>> dataset = dataset.apply(apply_func) 1257 1258 Raises: 1259 TypeError: If apply_func is not a function. 1260 TypeError: If apply_func doesn't return a Dataset. 1261 """ 1262 1263 if not hasattr(apply_func, '__call__'): 1264 raise TypeError("apply_func must be a function.") 1265 1266 dataset = apply_func(self) 1267 if not isinstance(dataset, Dataset): 1268 raise TypeError("apply_func must return a dataset.") 1269 return dataset 1270 1271 @check_device_send 1272 def device_que(self, send_epoch_end=True, create_data_info_queue=False): 1273 """ 1274 Return a transferred Dataset that transfers data through a device. 1275 1276 Args: 1277 send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True). 1278 create_data_info_queue (bool, optional): Whether to create queue which stores 1279 types and shapes of data or not(default=False). 1280 1281 Note: 1282 If device is Ascend, features of data will be transferred one by one. The limitation 1283 of data transmission per time is 256M. 1284 1285 Returns: 1286 TransferDataset, dataset for transferring. 1287 """ 1288 return self.to_device(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue) 1289 1290 @check_device_send 1291 def to_device(self, send_epoch_end=True, create_data_info_queue=False): 1292 """ 1293 Transfer data from CPU to GPU or Ascend or other devices. 1294 1295 Args: 1296 send_epoch_end (bool, optional): Whether to send the end of sequence to device or not (default=True). 1297 create_data_info_queue (bool, optional): Whether to create queue which stores 1298 types and shapes of data or not(default=False). 1299 1300 Note: 1301 If device is Ascend, features of data will be transferred one by one. The limitation 1302 of data transmission per second is 256M. 1303 1304 Returns: 1305 TransferDataset, dataset for transferring. 1306 1307 Raises: 1308 RuntimeError: If distribution file path is given but failed to read. 1309 """ 1310 return TransferDataset(self, send_epoch_end, create_data_info_queue) 1311 1312 @check_save 1313 def save(self, file_name, num_files=1, file_type='mindrecord'): 1314 """ 1315 Save the dynamic data processed by the dataset pipeline in common dataset format. 1316 Supported dataset formats: 'mindrecord' only 1317 1318 Implicit type casting exists when saving data as 'mindrecord'. The transform table shows how to do type casting. 1319 1320 .. list-table:: Implicit Type Casting when Saving as 'mindrecord' 1321 :widths: 25 25 50 1322 :header-rows: 1 1323 1324 * - Type in 'dataset' 1325 - Type in 'mindrecord' 1326 - Details 1327 * - bool 1328 - None 1329 - Not supported 1330 * - int8 1331 - int32 1332 - 1333 * - uint8 1334 - bytes(1D uint8) 1335 - Drop dimension 1336 * - int16 1337 - int32 1338 - 1339 * - uint16 1340 - int32 1341 - 1342 * - int32 1343 - int32 1344 - 1345 * - uint32 1346 - int64 1347 - 1348 * - int64 1349 - int64 1350 - 1351 * - uint64 1352 - None 1353 - Not supported 1354 * - float16 1355 - float32 1356 - 1357 * - float32 1358 - float32 1359 - 1360 * - float64 1361 - float64 1362 - 1363 * - string 1364 - string 1365 - Multi-dimensional string not supported 1366 1367 Note: 1368 1. To save the samples in order, set dataset's shuffle to False and num_files to 1. 1369 2. Before calling the function, do not use batch operator, repeat operator or data augmentation operators 1370 with random attribute in map operator. 1371 3. When array dimension is variable, one-dimensional arrays or 1372 multi-dimensional arrays with variable dimension 0 are supported. 1373 4. Mindrecord does not support DE_UINT64, multi-dimensional DE_UINT8(drop dimension) nor 1374 multi-dimensional DE_STRING. 1375 1376 Args: 1377 file_name (str): Path to dataset file. 1378 num_files (int, optional): Number of dataset files (default=1). 1379 file_type (str, optional): Dataset format (default='mindrecord'). 1380 1381 """ 1382 ir_tree, api_tree = self.create_ir_tree() 1383 1384 runtime_context = cde.PythonRuntimeContext() 1385 runtime_context.Init() 1386 consumer = cde.PythonSaveToDisk(file_name, num_files, file_type) 1387 consumer.Init(ir_tree) 1388 runtime_context.AssignConsumer(consumer) 1389 1390 consumer.Save() 1391 _set_dataset_permissions(file_name, num_files) 1392 del api_tree 1393 1394 @check_tuple_iterator 1395 def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True): 1396 """ 1397 Create an iterator over the dataset. The datatype retrieved back will be a list of ndarrays. 1398 1399 To specify which columns to list and the order needed, use columns_list. If columns_list 1400 is not provided, the order of the columns will remain unchanged. 1401 1402 Args: 1403 columns (list[str], optional): List of columns to be used to specify the order of columns 1404 (default=None, means all columns). 1405 num_epochs (int, optional): Maximum number of epochs that iterator can be iterated. 1406 (default=-1, iterator can be iterated infinite number of epochs) 1407 output_numpy (bool, optional): Whether or not to output NumPy datatype. 1408 If output_numpy=False, iterator will output MSTensor (default=False). 1409 do_copy (bool, optional): when output data type is mindspore.Tensor, 1410 use this param to select the conversion method, only take False for better performance (default=True). 1411 1412 Returns: 1413 TupleIterator, tuple iterator over the dataset. 1414 1415 Examples: 1416 >>> # dataset is an instance object of Dataset 1417 >>> iterator = dataset.create_tuple_iterator() 1418 >>> for item in iterator: 1419 ... # item is a list 1420 ... print(type(item)) 1421 ... break 1422 <class 'list'> 1423 """ 1424 if output_numpy is None: 1425 output_numpy = False 1426 1427 if Dataset._noop_mode(): 1428 return DummyIterator(self, 'tuple') 1429 return TupleIterator(self, columns, num_epochs, output_numpy, do_copy) 1430 1431 @check_dict_iterator 1432 def create_dict_iterator(self, num_epochs=-1, output_numpy=False): 1433 """ 1434 Create an iterator over the dataset. The data retrieved will be a dictionary datatype. 1435 1436 The order of the columns in the dictionary may not be the same as the original order. 1437 1438 Args: 1439 num_epochs (int, optional): Maximum number of epochs that iterator can be iterated 1440 (default=-1, iterator can be iterated infinite number of epochs). 1441 output_numpy (bool, optional): Whether or not to output NumPy datatype, 1442 if output_numpy=False, iterator will output MSTensor (default=False). 1443 1444 Returns: 1445 DictIterator, dictionary iterator over the dataset. 1446 1447 Examples: 1448 >>> # dataset is an instance object of Dataset 1449 >>> iterator = dataset.create_dict_iterator() 1450 >>> for item in iterator: 1451 ... # item is a dict 1452 ... print(type(item)) 1453 ... break 1454 <class 'dict'> 1455 """ 1456 if output_numpy is None: 1457 output_numpy = False 1458 1459 if Dataset._noop_mode(): 1460 return DummyIterator(self, 'dict') 1461 return DictIterator(self, num_epochs, output_numpy) 1462 1463 def __iter__(self): 1464 """Create an iterator over the dataset.""" 1465 return self.create_tuple_iterator(num_epochs=1) 1466 1467 @property 1468 def input_indexs(self): 1469 """ 1470 Get Input Index Information 1471 1472 Returns: 1473 tuple, tuple of the input index information. 1474 1475 Examples: 1476 >>> # dataset is an instance object of Dataset 1477 >>> # set input_indexs 1478 >>> dataset.input_indexs = 10 1479 >>> print(dataset.input_indexs) 1480 10 1481 """ 1482 if self._input_indexs != (): 1483 return self._input_indexs 1484 1485 # find input_indexes of children 1486 children_input_index = [child.input_indexs for child in self.children] 1487 1488 # in case of more than one child, return the first input_indexes 1489 for cix in children_input_index: 1490 if cix != (): 1491 return cix 1492 1493 # if all children's input_indexes are () or the node is a leaf 1494 return self._input_indexs 1495 1496 @input_indexs.setter 1497 def input_indexs(self, value): 1498 self._input_indexs = value 1499 1500 def copy_batch_size(self, value): 1501 self._batch_size = value 1502 1503 def _init_tree_getters(self): 1504 """ 1505 Get pipeline information. 1506 """ 1507 ir_tree, api_tree = self.create_ir_tree() 1508 1509 runtime_context = cde.PythonRuntimeContext() 1510 runtime_context.Init() 1511 getter = cde.TreeGetters() 1512 getter.Init(ir_tree) 1513 runtime_context.AssignConsumer(getter) 1514 return getter, runtime_context, api_tree 1515 1516 def __init_size_getter(self): 1517 """ 1518 Get pipeline information. 1519 """ 1520 ir_tree, api_tree = self.create_ir_tree() 1521 1522 runtime_context = cde.PythonRuntimeContext() 1523 runtime_context.Init() 1524 getter = cde.DatasetSizeGetters() 1525 getter.Init(ir_tree) 1526 runtime_context.AssignConsumer(getter) 1527 return getter, runtime_context, api_tree 1528 1529 def get_col_names(self): 1530 """ 1531 Return the names of the columns in dataset. 1532 1533 Returns: 1534 list, list of column names in the dataset. 1535 1536 Examples: 1537 >>> # dataset is an instance object of Dataset 1538 >>> col_names = dataset.get_col_names() 1539 """ 1540 if self._col_names is None: 1541 runtime_getter = self._init_tree_getters() 1542 self._col_names = runtime_getter[0].GetColumnNames() 1543 self.close_pool() 1544 runtime_getter[2].notify_watchdog() 1545 return self._col_names 1546 1547 def output_shapes(self): 1548 """ 1549 Get the shapes of output data. 1550 1551 Returns: 1552 list, list of shapes of each column. 1553 1554 Examples: 1555 >>> # dataset is an instance object of Dataset 1556 >>> output_shapes = dataset.output_shapes() 1557 """ 1558 if self.saved_output_shapes is None: 1559 runtime_getter = self._init_tree_getters() 1560 self.saved_output_shapes = runtime_getter[0].GetOutputShapes() 1561 self.saved_output_types = runtime_getter[0].GetOutputTypes() 1562 self.close_pool() 1563 runtime_getter[2].notify_watchdog() 1564 if self.dynamic_setting[0]: 1565 self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes = self._dynamic_output_shapes() 1566 return self.saved_output_shapes 1567 1568 def output_types(self): 1569 """ 1570 Get the types of output data. 1571 1572 Returns: 1573 list, list of data types. 1574 1575 Examples: 1576 >>> # dataset is an instance object of Dataset 1577 >>> output_types = dataset.output_types() 1578 """ 1579 if self.saved_output_types is None: 1580 runtime_getter = self._init_tree_getters() 1581 self.saved_output_shapes = runtime_getter[0].GetOutputShapes() 1582 self.saved_output_types = runtime_getter[0].GetOutputTypes() 1583 self.close_pool() 1584 runtime_getter[2].notify_watchdog() 1585 if self.dynamic_setting[0]: 1586 self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes = self._dynamic_output_shapes() 1587 return self.saved_output_types 1588 1589 def get_dataset_size(self): 1590 """ 1591 Return the number of batches in an epoch. 1592 1593 Returns: 1594 int, number of batches. 1595 1596 Examples: 1597 >>> # dataset is an instance object of Dataset 1598 >>> dataset_size = dataset.get_dataset_size() 1599 """ 1600 if self.dataset_size is None: 1601 runtime_getter = self.__init_size_getter() 1602 self.dataset_size = runtime_getter[0].GetDatasetSize(False) 1603 self.close_pool() 1604 runtime_getter[2].notify_watchdog() 1605 return self.dataset_size 1606 1607 def set_dynamic_columns(self, columns=None): 1608 """ 1609 Set dynamic shape information of source data, it should be set after the pipeline is defined. 1610 1611 Args: 1612 columns (dict): A dict contains shape information of each column in dataset. 1613 The value of shape[i] is :py:obj:`None` indicates that the data length of shape[i] is dynamic. 1614 1615 Examples: 1616 >>> import numpy as np 1617 >>> 1618 >>> def generator1(): 1619 >>> for i in range(1, 100): 1620 >>> yield np.ones((16, i, 83)), np.array(i) 1621 >>> 1622 >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"]) 1623 >>> dataset.set_dynamic_columns(columns={"data1": [16, None, 83], "data2": []}) 1624 """ 1625 if not isinstance(columns, dict): 1626 raise TypeError("Pass a dict to set dynamic shape, example: {\"data1\": [16, None, 256]}") 1627 self.dynamic_setting[0] = True 1628 self.dynamic_setting[1] = columns 1629 1630 def dynamic_min_max_shapes(self): 1631 """ 1632 Get minimum and maximum data length of dynamic source data, for dynamic graph compilation. 1633 1634 Returns: 1635 lists, min_shapes, max_shapes of source data. 1636 1637 Examples: 1638 >>> import numpy as np 1639 >>> 1640 >>> def generator1(): 1641 >>> for i in range(1, 100): 1642 >>> yield np.ones((16, i, 83)), np.array(i) 1643 >>> 1644 >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"]) 1645 >>> dataset.set_dynamic_columns(columns={"data1": [16, None, 83], "data2": []}) 1646 >>> min_shapes, max_shapes = dataset.dynamic_min_max_shapes() 1647 """ 1648 if self.saved_min_shapes is None or self.saved_max_shapes is None: 1649 self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes = self._dynamic_output_shapes() 1650 return self.saved_min_shapes, self.saved_max_shapes 1651 1652 def _dynamic_output_shapes(self): 1653 """ 1654 Get dynamic information of source data. 1655 1656 Returns: 1657 lists, dynamic_shapes, min_shapes, max_shapes of source data. 1658 """ 1659 if not self.dynamic_setting[1]: 1660 raise RuntimeError("dynamic_columns is not set, call set_dynamic_columns() by final Dataset Op.") 1661 1662 if self.saved_output_shapes is not None and self.saved_min_shapes is not None and \ 1663 self.saved_max_shapes is not None: 1664 return self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes 1665 1666 logger.warning("Calculating dynamic shape of input data, this will take a few minutes...") 1667 # Assume data1 shape is dynamic, data2 shape is fix 1668 # {"data1": [batch_size, None, feat_len], "data2": [batch_size, feat_len]} 1669 dynamic_columns = self.dynamic_setting[1] 1670 # ["data1", "data2"] 1671 dataset_columns = self.get_col_names() 1672 for column in dynamic_columns: 1673 if column not in dataset_columns: 1674 raise RuntimeError("dynamic column [" + column + "] does not match any column in dataset: " + 1675 str(dataset_columns)) 1676 1677 # Shape[1] of data1 is variable 1678 # {"data1": {(batch_size, 100, feat_len), (16, 200, 83)}, "data2": {(batch_size, feat_len)}} 1679 column_shape_set = {col: set() for col in dataset_columns} 1680 dataset_size_counter = 0 1681 for data in self.create_dict_iterator(num_epochs=1, output_numpy=True): 1682 dataset_size_counter += 1 1683 for col in data.keys(): 1684 if col in dynamic_columns: 1685 shape_mismatch = "dynamic column [" + col + "] with shape " + str(dynamic_columns[col]) + \ 1686 " does not match dataset column [" + col + "] with shape " + str(list(data[col].shape)) 1687 if data[col].ndim != len(dynamic_columns[col]): 1688 raise RuntimeError(shape_mismatch) 1689 for dim in range(len(dynamic_columns[col])): 1690 if dynamic_columns[col][dim] is not None and dynamic_columns[col][dim] != data[col].shape[dim]: 1691 raise RuntimeError(shape_mismatch) 1692 column_shape_set[col].add(tuple(data[col].shape)) 1693 1694 # we get dataset_size after dryrun 1695 self.dataset_size = dataset_size_counter 1696 1697 min_shapes, max_shapes, dynamic_shapes = list(), list(), list() 1698 for col, shape_set in column_shape_set.items(): 1699 if len(shape_set) > 1: 1700 if col not in dynamic_columns: 1701 raise RuntimeError("column [" + col + "] has dynamic shape but not set by set_dynamic_columns()" + 1702 ", shapes of [" + col + "]: " + str(list(shape_set))) 1703 shape_npy = np.array(list(shape_set)) 1704 max_shape = shape_npy.max(axis=0) 1705 min_shape = shape_npy.min(axis=0) 1706 1707 # Set min shape to 1 due to unknown shuffle 1708 min_shape = np.where(np.equal(dynamic_columns[col], None), 1, min_shape) 1709 # Set dynamic dim to -1 for ME 1710 dynamic_shape = np.where(np.equal(dynamic_columns[col], None), -1, dynamic_columns[col]) 1711 1712 max_shapes.append(max_shape.tolist()) 1713 min_shapes.append(min_shape.tolist()) 1714 dynamic_shapes.append(dynamic_shape.tolist()) 1715 else: 1716 # Also append fix shape to keep order of column shape 1717 fix_shape = list(list(shape_set)[0]) 1718 max_shapes.append(fix_shape) 1719 min_shapes.append(fix_shape) 1720 dynamic_shapes.append(fix_shape) 1721 if col in dynamic_columns: 1722 logger.warning("column [" + col + "] has no dynamic shape but set by set_dynamic_columns()") 1723 # Set min shape to 1 due to unknown shuffle 1724 min_shapes[-1] = np.where(np.equal(dynamic_columns[col], None), 1, fix_shape).tolist() 1725 # Set dynamic dim to -1 for ME 1726 dynamic_shapes[-1] = np.where(np.equal(dynamic_columns[col], None), -1, fix_shape).tolist() 1727 return dynamic_shapes, min_shapes, max_shapes 1728 1729 def num_classes(self): 1730 """ 1731 Get the number of classes in a dataset. 1732 1733 Returns: 1734 int, number of classes. 1735 1736 Examples: 1737 >>> # dataset is an instance object of Dataset 1738 >>> num_classes = dataset.num_classes() 1739 """ 1740 if self._num_classes is None: 1741 runtime_getter = self._init_tree_getters() 1742 self._num_classes = runtime_getter[0].GetNumClasses() 1743 self.close_pool() 1744 runtime_getter[2].notify_watchdog() 1745 if self._num_classes == -1: 1746 return None 1747 return self._num_classes 1748 1749 def get_sync_notifiers(self): 1750 if self.children: 1751 return self.children[0].get_sync_notifiers() 1752 return {} 1753 1754 def disable_sync(self): 1755 if self.children: 1756 return self.children[0].disable_sync() 1757 return {} 1758 1759 def is_sync(self): 1760 if self.children: 1761 return self.children[0].is_sync() 1762 return False 1763 1764 def sync_update(self, condition_name, num_batch=None, data=None): 1765 """ 1766 Release a blocking condition and trigger callback with given data. 1767 1768 Args: 1769 condition_name (str): The condition name that is used to toggle sending next row. 1770 num_batch (Union[int, None]): The number of batches (rows) that are released. 1771 When num_batch is None, it will default to the number specified by the 1772 sync_wait operator (default=None). 1773 data (Any): The data passed to the callback, user defined (default=None). 1774 """ 1775 if (not isinstance(num_batch, int) and num_batch is not None) or \ 1776 (isinstance(num_batch, int) and num_batch <= 0): 1777 # throwing exception, disable all sync_wait in pipeline 1778 self.disable_sync() 1779 raise RuntimeError("Sync_update batch size can only be positive integer, got : {}.".format(num_batch)) 1780 notifiers_dict = self.get_sync_notifiers() 1781 if not isinstance(condition_name, str): 1782 raise TypeError("Argument condition_name with value {} is not of type str, but got {}." 1783 .format(condition_name, type(condition_name))) 1784 if condition_name not in notifiers_dict: 1785 # throwing exception, disable all sync_wait in pipeline 1786 self.disable_sync() 1787 raise RuntimeError("Condition name not found.") 1788 if num_batch is not None: 1789 num_batch *= self.get_batch_size() 1790 notifiers_dict[condition_name](num_batch, data) 1791 1792 def get_batch_size(self): 1793 """ 1794 Return the size of batch. 1795 1796 Returns: 1797 int, the number of data in a batch. 1798 1799 Examples: 1800 >>> # dataset is an instance object of Dataset 1801 >>> batch_size = dataset.get_batch_size() 1802 """ 1803 if self._batch_size is None: 1804 runtime_getter = self._init_tree_getters() 1805 self._batch_size = runtime_getter[0].GetBatchSize() 1806 if self._batch_size is None: 1807 self._batch_size = 1 1808 return self._batch_size 1809 1810 def get_repeat_count(self): 1811 """ 1812 Get the replication times in RepeatDataset (default is 1). 1813 1814 Returns: 1815 int, the count of repeat. 1816 1817 Examples: 1818 >>> # dataset is an instance object of Dataset 1819 >>> repeat_count = dataset.get_repeat_count() 1820 """ 1821 if self._repeat_count is None: 1822 runtime_getter = self._init_tree_getters() 1823 self._repeat_count = runtime_getter[0].GetRepeatCount() 1824 if self._repeat_count is None: 1825 self._repeat_count = 1 1826 return self._repeat_count 1827 1828 def get_class_indexing(self): 1829 """ 1830 Return the class index. 1831 1832 Returns: 1833 dict, a str-to-int mapping from label name to index. 1834 dict, a str-to-list<int> mapping from label name to index for Coco ONLY. The second number 1835 in the list is used to indicate the super category. 1836 1837 Examples: 1838 >>> # dataset is an instance object of Dataset 1839 >>> class_indexing = dataset.get_class_indexing() 1840 """ 1841 if self.children: 1842 return self.children[0].get_class_indexing() 1843 return {} 1844 1845 def reset(self): 1846 """Reset the dataset for next epoch.""" 1847 1848 def is_shuffled(self): 1849 """Returns True if the dataset or its children is shuffled.""" 1850 for input_dataset in self.children: 1851 if input_dataset.is_shuffled(): 1852 return True 1853 1854 return False 1855 1856 def is_sharded(self): 1857 """Returns True if the dataset or its children is sharded.""" 1858 for input_dataset in self.children: 1859 if input_dataset.is_sharded(): 1860 return True 1861 1862 return False 1863 1864 def parse(self, children=None): 1865 raise NotImplementedError("Dataset has to implement parse method.") 1866 1867 def post_parse(self, ir_node): 1868 if self.cache: 1869 ir_node = ir_node.set_cache_client(self.cache.cache_client) 1870 if self.num_parallel_workers: 1871 ir_node = ir_node.set_num_workers(self.num_parallel_workers) 1872 1873 return ir_node 1874 1875 1876class SourceDataset(Dataset): 1877 """ 1878 Abstract class to represent a source dataset which produces content to the data pipeline. 1879 """ 1880 1881 def __init__(self, num_parallel_workers=None, num_samples=None, shuffle=True, num_shards=None, shard_id=None, 1882 cache=None): 1883 super().__init__(num_parallel_workers=num_parallel_workers, cache=cache) 1884 self.num_samples = replace_none(num_samples, 0) 1885 self.num_shards = replace_none(num_shards, 1) 1886 self.shard_id = replace_none(shard_id, 0) 1887 1888 if shuffle is not None and not isinstance(shuffle, (bool, Shuffle)): 1889 raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like 'Shuffle.GLOBAL' or " 1890 "'Shuffle.FILES' or 'Shuffle.INFILE'.") 1891 1892 self.shuffle_flag = 2 # Global shuffle 1893 if not isinstance(shuffle, Shuffle): 1894 if shuffle is None or shuffle: 1895 self.shuffle_flag = 2 # Global shuffle 1896 else: 1897 self.shuffle_flag = 0 # No shuffle 1898 else: 1899 if shuffle == Shuffle.GLOBAL: 1900 self.shuffle_flag = 2 # Global shuffle 1901 elif shuffle == Shuffle.FILES: 1902 self.shuffle_flag = 1 # Files shuffle 1903 elif shuffle == Shuffle.INFILE: 1904 self.shuffle_flag = 3 # Infile shuffle 1905 1906 def parse(self, children=None): 1907 raise NotImplementedError("Dataset has to implement parse method.") 1908 1909 @staticmethod 1910 def _find_files(patterns): 1911 """ 1912 Utility function to search for files with the given glob patterns. 1913 1914 Args: 1915 patterns (Union[str, list[str]]): String or list of patterns to be searched. 1916 1917 Returns: 1918 list, list of files. 1919 """ 1920 1921 if not isinstance(patterns, list): 1922 patterns = [patterns] 1923 1924 file_list = [] 1925 unmatched_patterns = [] 1926 for pattern in patterns: 1927 matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)] 1928 1929 if matches: 1930 file_list.extend(matches) 1931 else: 1932 unmatched_patterns.append(pattern) 1933 1934 if unmatched_patterns: 1935 raise ValueError("The following patterns did not match any files: {}.".format(unmatched_patterns)) 1936 1937 if file_list: # not empty 1938 return file_list 1939 raise ValueError("The list of path names matching the patterns is empty.") 1940 1941 def is_shuffled(self): 1942 return self.shuffle_flag > 0 1943 1944 def is_sharded(self): 1945 if self.num_shards is not None: 1946 return self.num_shards > 1 1947 return False 1948 1949 1950class MappableDataset(SourceDataset): 1951 """ 1952 Abstract class to represent a source dataset which supports use of samplers. 1953 """ 1954 1955 def parse(self, children=None): 1956 raise NotImplementedError("Dataset has to implement parse method.") 1957 1958 def __init__(self, num_parallel_workers=None, sampler=None, num_samples=None, shuffle=None, num_shards=None, 1959 shard_id=None, cache=None): 1960 super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, 1961 num_shards=num_shards, shard_id=shard_id, cache=cache) 1962 self.shuffle_flag = replace_none(shuffle, True) 1963 self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) 1964 1965 def add_sampler(self, new_sampler): 1966 """ 1967 Add a sampler for current dataset,. 1968 1969 Args: 1970 new_sampler (Sampler): The sampler to be added as the parent sampler for current dataset. 1971 1972 Examples: 1973 >>> # dataset is an instance object of Dataset 1974 >>> # use a DistributedSampler instead 1975 >>> new_sampler = ds.DistributedSampler(10, 2) 1976 >>> dataset.add_sampler(new_sampler) 1977 """ 1978 # note: By adding a sampler, the sampled IDs will flow to new_sampler 1979 # after first passing through the current samplers attached to this dataset. 1980 self.dataset_size = None 1981 new_sampler.add_child(self.sampler) 1982 self.sampler = new_sampler 1983 1984 def use_sampler(self, new_sampler): 1985 """ 1986 Make the current dataset use the new_sampler provided by other API. 1987 1988 Args: 1989 new_sampler (Sampler): The sampler to use for the current dataset. 1990 1991 Examples: 1992 >>> # dataset is an instance object of Dataset 1993 >>> # use a DistributedSampler instead 1994 >>> new_sampler = ds.DistributedSampler(10, 2) 1995 >>> dataset.use_sampler(new_sampler) 1996 """ 1997 if new_sampler is None: 1998 raise TypeError("Input sampler can not be None.") 1999 if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): 2000 raise TypeError("Input sampler is not an instance of a sampler.") 2001 self.dataset_size = None 2002 2003 self.sampler = self.sampler.child_sampler 2004 self.add_sampler(new_sampler) 2005 2006 def is_shuffled(self): 2007 return self.sampler.is_shuffled() 2008 2009 def is_sharded(self): 2010 return self.sampler.is_sharded() 2011 2012 @check_split 2013 def split(self, sizes, randomize=True): 2014 """ 2015 Split the dataset into smaller, non-overlapping datasets. 2016 2017 Args: 2018 sizes (Union[list[int], list[float]]): If a list of integers [s1, s2, …, sn] is 2019 provided, the dataset will be split into n datasets of size s1, size s2, …, size sn 2020 respectively. If the sum of all sizes does not equal the original dataset size, an 2021 error will occur. 2022 If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1 2023 and must sum to 1, otherwise an error will occur. The dataset will be split into n 2024 Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the 2025 original dataset. 2026 If after rounding: 2027 2028 - Any size equals 0, an error will occur. 2029 - The sum of split sizes < K, the difference will be added to the first split. 2030 - The sum of split sizes > K, the difference will be removed from the first large 2031 enough split such that it will have at least 1 row after removing the difference. 2032 2033 randomize (bool, optional): Determines whether or not to split the data randomly (default=True). 2034 If True, the data will be randomly split. Otherwise, each split will be created with 2035 consecutive rows from the dataset. 2036 2037 Note: 2038 1. There is an optimized split function, which will be called automatically when the dataset 2039 that calls this function is a MappableDataset. 2040 2. Dataset should not be sharded if split is going to be called. Instead, create a 2041 DistributedSampler and specify a split to shard after splitting. If the dataset is 2042 sharded after a split, it is strongly recommended setting the same seed in each instance 2043 of execution, otherwise each shard may not be part of the same split (see Examples). 2044 3. It is strongly recommended to not shuffle the dataset, but use randomize=True instead. 2045 Shuffling the dataset may not be deterministic, which means the data in each split 2046 will be different in each epoch. Furthermore, if sharding occurs after split, each 2047 shard may not be part of the same split. 2048 2049 Raises: 2050 RuntimeError: If get_dataset_size returns None or is not supported for this dataset. 2051 RuntimeError: If `sizes` is list of integers and sum of all elements in sizes does not 2052 equal the dataset size. 2053 RuntimeError: If `sizes` is list of float and there is a split with size 0 after calculations. 2054 RuntimeError: If the dataset is sharded prior to calling split. 2055 ValueError: If `sizes` is list of float and not all floats are between 0 and 1, or if the 2056 floats don't sum to 1. 2057 2058 Returns: 2059 tuple(Dataset), a tuple of datasets that have been split. 2060 2061 Examples: 2062 >>> # Since many datasets have shuffle on by default, set shuffle to False if split will be called! 2063 >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, shuffle=False) 2064 >>> 2065 >>> # Set the seed, and tell split to use this seed when randomizing. 2066 >>> # This is needed because sharding will be done later 2067 >>> ds.config.set_seed(58) 2068 >>> train_dataset, test_dataset = dataset.split([0.9, 0.1]) 2069 >>> 2070 >>> # To shard the train dataset, use a DistributedSampler 2071 >>> train_sampler = ds.DistributedSampler(10, 2) 2072 >>> train_dataset.use_sampler(train_sampler) 2073 """ 2074 if self.is_shuffled(): 2075 logger.warning("Dataset is shuffled before split.") 2076 2077 if self.is_sharded(): 2078 raise RuntimeError("Dataset should not be sharded before split.") 2079 2080 absolute_sizes = self._get_absolute_split_sizes(sizes) 2081 splits = [] 2082 current_split_start_index = 0 2083 for size in absolute_sizes: 2084 ds = copy.deepcopy(self) 2085 ds.dataset_size = None 2086 if randomize: 2087 # want to shuffle the same way every epoch before split, we are assuming 2088 # that the user will call set_seed 2089 random_sampler = samplers.RandomSampler() 2090 random_sampler.reshuffle_each_epoch = False 2091 ds.add_sampler(random_sampler) 2092 2093 subset_sampler = samplers.SequentialSampler(current_split_start_index, size) 2094 ds.add_sampler(subset_sampler) 2095 2096 # add sequential sampler, so that if user calls use_sampler, we will 2097 # get rid of the sequential sampler instead of something we need 2098 ds.add_sampler(samplers.SequentialSampler()) 2099 2100 splits.append(ds) 2101 2102 current_split_start_index += size 2103 2104 return tuple(splits) 2105 2106 2107class BucketBatchByLengthDataset(Dataset): 2108 """ 2109 The result of applying BucketBatchByLength operator to the input dataset. 2110 """ 2111 2112 def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, 2113 pad_info, pad_to_bucket_boundary, drop_remainder): 2114 super().__init__(children=input_dataset) 2115 2116 self.column_names = to_list(column_names) 2117 self.bucket_boundaries = replace_none(bucket_boundaries, []) 2118 self.bucket_batch_sizes = replace_none(bucket_batch_sizes, []) 2119 self.element_length_function = element_length_function 2120 self.pad_info = replace_none(pad_info, {}) 2121 self.pad_to_bucket_boundary = replace_none(pad_to_bucket_boundary, False) 2122 self.drop_remainder = replace_none(drop_remainder, False) 2123 2124 def parse(self, children=None): 2125 return cde.BucketBatchByLengthNode(children[0], self.column_names, self.bucket_boundaries, 2126 self.bucket_batch_sizes, self.element_length_function, self.pad_info, 2127 self.pad_to_bucket_boundary, self.drop_remainder) 2128 2129 2130class BatchDataset(Dataset): 2131 """ 2132 The result of applying Batch operator to the input dataset. 2133 2134 Args: 2135 input_dataset (Dataset): Input Dataset to be batched. 2136 batch_size (Union[int, function]): The number of rows each batch is created with. An 2137 int or callable which takes exactly 1 parameter, BatchInfo. 2138 drop_remainder (bool, optional): Determines whether or not to drop the last 2139 possibly incomplete batch (default=False). If True, and if there are less 2140 than batch_size rows available to make the last batch, then those rows will 2141 be dropped and not propagated to the child node. 2142 num_parallel_workers (int, optional): Number of workers to process the dataset in parallel (default=None). 2143 per_batch_map (callable, optional): Per batch map callable. A callable which takes 2144 (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch of 2145 Tensors on a given column. The number of lists should match with number of entries in input_columns. The 2146 last parameter of the callable must always be a BatchInfo object. 2147 input_columns (Union[str, list[str]], optional): List of names of the input columns. The size of the list must 2148 match with signature of per_batch_map callable. 2149 output_columns (Union[str, list[str]], optional): List of names assigned to the columns outputted by 2150 the last operation. This parameter is mandatory if len(input_columns) != 2151 len(output_columns). The size of this list must match the number of output 2152 columns of the last operation. (default=None, output columns will have the same 2153 name as the input columns, i.e., the columns will be replaced). 2154 column_order (Union[str, list[str]], optional): Specifies the list of all the columns you need in the whole 2155 dataset. The parameter is required when len(input_column) != len(output_column). Caution: the list here 2156 is not just the columns specified in parameter input_columns and output_columns. 2157 pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)} 2158 will pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0. 2159 max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy 2160 data between processes. This is only used if python_multiprocessing is set to True (default 16 MB). 2161 2162 """ 2163 2164 def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, 2165 input_columns=None, output_columns=None, column_order=None, pad_info=None, 2166 python_multiprocessing=False, max_rowsize=16): 2167 super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers) 2168 2169 if BatchDataset._is_ancestor_of_repeat(input_dataset): 2170 logger.warning("Repeat is located before batch, data from two epochs can be batched together.") 2171 2172 BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) 2173 2174 # if batch_size is callable, set batch_size to 1 and batch_size_func to that callable function 2175 self.batch_size = batch_size if not callable(batch_size) else 1 2176 self.batch_size_func = None if not callable(batch_size) else batch_size 2177 2178 self.drop_remainder = replace_none(drop_remainder, False) 2179 2180 self.per_batch_map = per_batch_map 2181 2182 self.input_columns = to_list(input_columns) 2183 self.output_columns = to_list(output_columns) 2184 self.column_order = to_list(column_order) 2185 2186 self.pad = bool(pad_info is not None) 2187 self.pad_info = replace_none(pad_info, dict()) 2188 2189 self.python_multiprocessing = python_multiprocessing 2190 self.process_pool = None 2191 self.hook = None 2192 self.pids = [] 2193 self.eot = None 2194 self.watch_dog = None 2195 self.max_rowsize = max_rowsize 2196 2197 def parse(self, children=None): 2198 return cde.BatchNode(children[0], self.batch_size, self.drop_remainder, self.pad, self.input_columns, 2199 self.output_columns, self.column_order, self.batch_size_func, self.per_batch_map, 2200 self.pad_info) 2201 2202 @staticmethod 2203 def _is_ancestor_of_repeat(dataset): 2204 """ 2205 Utility function to find the case where repeat is used before batch. 2206 2207 Args: 2208 dataset (Dataset): Dataset to be checked. 2209 2210 Returns: 2211 bool, whether repeat is used before batch. 2212 """ 2213 if isinstance(dataset, RepeatDataset): 2214 return True 2215 flag = False 2216 for input_dataset in dataset.children: 2217 flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset) 2218 return flag 2219 2220 @staticmethod 2221 def _update_batch_size_for_syncwait(dataset, batch_size): 2222 """ 2223 Utility function to notify batch size to sync_wait. 2224 2225 Args: 2226 dataset (Dataset): Dataset to be checked. 2227 batch_size (int): batch size to notify. 2228 """ 2229 if isinstance(dataset, SyncWaitDataset): 2230 dataset.update_sync_batch_size(batch_size) 2231 for input_dataset in dataset.children: 2232 BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) 2233 2234 def __deepcopy__(self, memodict): 2235 return self.__safe_deepcopy__(memodict, exclude=("per_batch_map", "batch_size_func", "__transfer_dataset__")) 2236 2237 # Iterator bootstrap will be called on iterator construction. 2238 # A deep copy of Dataset object is created prior of iterator_bootstrap. 2239 # This method will create per iterator process pool and bind pyfunc execution to the pool. 2240 def iterator_bootstrap(self): 2241 """ 2242 Per iterator bootstrap callback. 2243 """ 2244 if self.python_multiprocessing: 2245 if self.per_batch_map is None: 2246 logger.warning("per_batch_map is None so python_multiprocessing does not work.") 2247 return 2248 arg_q_list = [] 2249 res_q_list = [] 2250 2251 # If user didn't specify num_parallel_workers, set it to default 2252 if self.num_parallel_workers is not None: 2253 num_parallel = self.num_parallel_workers 2254 else: 2255 num_parallel = get_num_parallel_workers() 2256 2257 if get_enable_shared_mem(): 2258 _check_shm_usage(num_parallel, 1, self.max_rowsize * self.batch_size, 2) 2259 for _ in range(num_parallel): 2260 arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size)) 2261 res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size)) 2262 2263 # Construct pool with the callable list 2264 # The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses 2265 self.process_pool = multiprocessing.Pool(processes=num_parallel, 2266 initializer=_pyfunc_worker_init, 2267 initargs=([self.per_batch_map], arg_q_list, res_q_list)) 2268 2269 idx = 0 2270 global _OP_NAME, _OP_PROCESS, _LOCK 2271 op_id = _OP_NAME[str(self)] 2272 process_id = {op_id: [self.num_parallel_workers, set()]} 2273 # obtain process id from multiprocessing.pool 2274 for pool in self.process_pool._pool: # pylint: disable=W0212 2275 process_id[op_id][1].add(pool.pid) 2276 self.pids.append(pool.pid) 2277 with _LOCK: 2278 _OP_PROCESS.update(process_id) 2279 2280 # Wrap per_batch_map into _PythonCallable 2281 self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool, arg_q_list, res_q_list) 2282 self.hook = _ExceptHookHandler() 2283 atexit.register(_mp_pool_exit_preprocess) 2284 # If Python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown. 2285 if sys.version_info >= (3, 8): 2286 atexit.register(self.process_pool.close) 2287 if platform.system().lower() != 'windows': 2288 self.eot = threading.Event() 2289 self.watch_dog = threading.Thread(target=_watch_dog, args=(self.eot, self.pids)) 2290 self.watch_dog.daemon = True 2291 self.watch_dog.start() 2292 else: 2293 if self.per_batch_map is not None: 2294 self.per_batch_map = FuncWrapper(self.per_batch_map) 2295 2296 def _abort_watchdog(self): 2297 if not self.eot.is_set(): 2298 self.eot.set() 2299 2300 def __del__(self): 2301 if hasattr(self, 'process_pool') and self.process_pool is not None: 2302 self.process_pool.close() 2303 if hasattr(self, 'watch_dog') and self.watch_dog is not None and hasattr(self, 'eot') and self.eot is not None: 2304 self._abort_watchdog() 2305 2306 2307class BatchInfo(cde.CBatchInfo): 2308 """ 2309 The information object associates with the current batch of tensors. 2310 """ 2311 2312 def get_batch_num(self): 2313 """ 2314 Return the batch number of the current batch. 2315 """ 2316 return 2317 2318 def get_epoch_num(self): 2319 """ 2320 Return the epoch number of the current batch. 2321 """ 2322 return 2323 2324 2325class BlockReleasePair: 2326 """ 2327 The blocking condition class used by SyncWaitDataset. 2328 2329 Args: 2330 init_release_rows (int): Number of lines to allow through the pipeline. 2331 callback (function): The callback function that will be called when release is called (default=None). 2332 """ 2333 2334 def __init__(self, init_release_rows, callback=None): 2335 if isinstance(init_release_rows, int) and init_release_rows <= 0: 2336 raise ValueError("release_rows need to be greater than 0.") 2337 self.row_count = -init_release_rows 2338 self.cv = threading.Condition() 2339 self.callback = callback 2340 self.default_rows = init_release_rows 2341 self.disable = False 2342 2343 def __deepcopy__(self, memodict): 2344 return self 2345 2346 def reset(self): 2347 with self.cv: 2348 self.row_count = -self.default_rows 2349 self.cv.notify_all() 2350 2351 def update_batched_size(self, batch_size): 2352 # sanity check 2353 if isinstance(batch_size, int) and batch_size <= 0: 2354 raise ValueError("batch_size need to be greater than 0.") 2355 2356 # should only use before the pipeline creates 2357 self.row_count *= batch_size 2358 self.default_rows *= batch_size 2359 2360 def block_func(self): 2361 """ 2362 Function for handing blocking condition. 2363 2364 Returns: 2365 bool, True. 2366 """ 2367 with self.cv: 2368 # if disable is true, the always evaluate to true 2369 not_time_out = self.cv.wait_for(lambda: (self.row_count < 0 or self.disable), 2370 timeout=get_callback_timeout()) 2371 # time_out will be False if time out occurs 2372 if not not_time_out: 2373 logger.warning("Timeout happened in sync_wait, maybe dataset.sync_update(condition=...) " 2374 "is not added after dataset.create_dict_iterator(...), now disabling lock.") 2375 self.disable = True 2376 self.row_count += 1 2377 return True 2378 2379 def release_func(self, pass_rows=None, data=None): 2380 with self.cv: 2381 if pass_rows is None: 2382 pass_rows = self.default_rows 2383 self.row_count -= pass_rows 2384 if self.callback is not None: 2385 self.callback(data) 2386 self.cv.notify_all() 2387 2388 def disable_lock(self): 2389 with self.cv: 2390 self.disable = True 2391 self.cv.notify_all() 2392 2393 2394class SyncWaitDataset(Dataset): 2395 """ 2396 The result of adding a blocking condition to the input Dataset. 2397 2398 Args: 2399 input_dataset (Dataset): Input dataset to apply flow control. 2400 num_batch (int): Number of batches without blocking at the start of each epoch. 2401 condition_name (str): Condition name that is used to toggle sending next row. 2402 callback (function): Callback function that will be invoked when sync_update is called (default=None). 2403 2404 Raises: 2405 RuntimeError: If condition name already exists. 2406 """ 2407 2408 def __init__(self, input_dataset, condition_name, num_batch, callback=None): 2409 super().__init__(children=input_dataset) 2410 2411 # set to the default value, waiting for the batch to update it 2412 self._condition_name = condition_name 2413 if isinstance(num_batch, int) and num_batch <= 0: 2414 raise ValueError("num_batch need to be greater than 0.") 2415 2416 self._pair = BlockReleasePair(num_batch, callback) 2417 if self._condition_name in self.children[0].get_sync_notifiers(): 2418 raise RuntimeError("Condition name is already in use.") 2419 logger.info("Please remember to add dataset.sync_update(condition=%s), otherwise hanging will result. " 2420 "If dataset.sync_update(condition=%s) has already been added, you can ignore the info.", 2421 condition_name, condition_name) 2422 2423 def parse(self, children=None): 2424 return cde.SyncWaitNode(children[0], self._condition_name, self._pair.block_func) 2425 2426 def get_sync_notifiers(self): 2427 return {**self.children[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}} 2428 2429 def is_sync(self): 2430 return True 2431 2432 def update_sync_batch_size(self, batch_size): 2433 if isinstance(batch_size, int) and batch_size <= 0: 2434 raise ValueError("num_batch need to be greater than 0.") 2435 self._pair.update_batched_size(batch_size) 2436 2437 def disable_sync(self): 2438 logger.info("Disabling Sync") 2439 self._pair.disable_lock() 2440 2441 @staticmethod 2442 def _is_ancestor_of_batch(dataset): 2443 """ 2444 Utility function to find the case where sync_wait is used before batch. 2445 2446 Args: 2447 dataset (Dataset): Dataset to be checked. 2448 2449 Returns: 2450 bool, whether sync_wait is used before batch. 2451 """ 2452 if isinstance(dataset, BatchDataset): 2453 return True 2454 flag = False 2455 for input_dataset in dataset.children: 2456 flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset) 2457 return flag 2458 2459 def iterator_bootstrap(self): 2460 self._pair.reset() 2461 2462 2463class ShuffleDataset(Dataset): 2464 """ 2465 The result of applying Shuffle operator to the input Dataset. 2466 2467 Args: 2468 input_dataset (Dataset): Input Dataset to be shuffled. 2469 buffer_size (int): Size of the buffer. 2470 2471 Raises: 2472 RuntimeError: If exist sync operators before shuffle. 2473 """ 2474 2475 def __init__(self, input_dataset, buffer_size): 2476 super().__init__(children=input_dataset) 2477 self.buffer_size = buffer_size 2478 self.reshuffle_each_epoch = True 2479 2480 if self.is_sync(): 2481 raise RuntimeError("No shuffle after sync operators.") 2482 2483 def parse(self, children=None): 2484 return cde.ShuffleNode(children[0], self.buffer_size, self.reshuffle_each_epoch) 2485 2486 def is_shuffled(self): 2487 return True 2488 2489 2490# This wait function is for cleaning zombie subprocesses 2491def wait_pid(): 2492 try: 2493 while True: 2494 child_pid, _ = os.waitpid(-1, os.WNOHANG) 2495 if child_pid == 0: 2496 break 2497 except OSError: 2498 # waitpid may be failed for some reasons so we ignore this error 2499 pass 2500 2501 2502# Dataset need _watch_dog thread to monitoring fork multi-processing, 2503# and thread can't be a member function otherwise python won't collect and release resources. 2504def _watch_dog(eot, pids): 2505 """ 2506 This thread is for monitoring subprocesses forked by GeneratorDataset/MapDataset/BatchDataset 2507 """ 2508 while not eot.is_set(): 2509 subprocess_exit_num = 0 2510 # Monitoring and count how many subprocesses already exit 2511 for pid in pids: 2512 try: 2513 p = psutil.Process(pid) 2514 if p.status() == psutil.STATUS_ZOMBIE: 2515 subprocess_exit_num += 1 2516 except psutil.NoSuchProcess: 2517 subprocess_exit_num += 1 2518 # If find subprocess exit, we will wait for 30s and do some waitpid operations 2519 if subprocess_exit_num > 0: 2520 start = time.time() 2521 while time.time() - start < 30: 2522 # We need to distinguishing get_dataset_size or train finished normally and hang scenario. 2523 # If get_dataset_size or train finished normally, _stop_subprocess can be execute and 2524 # self.need_abort can be set to True. If main process is hang in get(), self.need_abort 2525 # will never set to True, then we wait for 30s and kill main process 2526 if eot.is_set(): 2527 return 2528 # Sometimes subprocess may be zombie, so in 30s we can wait and do some useful tasks(waitpid). 2529 wait_pid() 2530 ## multiprocessing.queue may hang in .get() forever when put() process was killed. 2531 ## We have to exit main process otherwise main process will hang. 2532 logger.error("The subprocess of dataset may exit unexpected or be killed, " 2533 "main process will exit.") 2534 os.kill(os.getpid(), signal.SIGTERM) 2535 2536 2537# Pyfunc collection for multiprocess pyfunc 2538# This global variable will only be used within subprocesses 2539_GLOBAL_PYFUNC_LIST = [] 2540_ARGS_QUEUE = [] 2541_RET_QUEUE = [] 2542_OP_NAME = dict() 2543_OP_PROCESS = dict() 2544_LOCK = threading.Lock() 2545 2546 2547# Pyfunc worker init function 2548# Python multiprocessing library forbid sending lambda function through pipe. 2549# This init function allow us to add all Python function to a global collection and then fork afterwards. 2550def _pyfunc_worker_init(pyfunc_list, args_queue, ret_queue): 2551 global _GLOBAL_PYFUNC_LIST 2552 global _ARGS_QUEUE 2553 global _RET_QUEUE 2554 _GLOBAL_PYFUNC_LIST = pyfunc_list 2555 _ARGS_QUEUE = args_queue 2556 _RET_QUEUE = ret_queue 2557 2558 2559# Pyfunc worker execution function 2560# All exceptions will be raised to main processes 2561def _pyfunc_worker_exec(index, qid, *args): 2562 """ 2563 Internal function for call certain pyfunc in Python process. 2564 """ 2565 # Some threads in multiprocess.pool can't process sigint signal, 2566 # and will occur hang problem, so ctrl+c will pass to parent process. 2567 signal.signal(signal.SIGINT, signal.SIG_IGN) 2568 2569 if qid != -1: 2570 # Pass arguments through the Queue instead of directly to remote process 2571 args = _ARGS_QUEUE[qid].get() 2572 try: 2573 r = _GLOBAL_PYFUNC_LIST[index](*args) 2574 except Exception: 2575 return ExceptionHandler(where="in map(or batch) worker and execute python function") 2576 if isinstance(r, tuple): 2577 _RET_QUEUE[qid].put(r) 2578 else: 2579 _RET_QUEUE[qid].put((r,)) 2580 return [qid] 2581 # not using shared memory for passing arguments, call function directly 2582 result = None 2583 try: 2584 result = _GLOBAL_PYFUNC_LIST[index](*args) 2585 except Exception: 2586 result = ExceptionHandler(where="in map(or batch) worker and execute python function") 2587 return result 2588 2589 2590# PythonCallable wrapper for multiprocess pyfunc 2591class _PythonCallable: 2592 """ 2593 Internal Python function wrapper for multiprocessing pyfunc. 2594 """ 2595 2596 def __init__(self, py_callable, idx, pool=None, arg_q=None, res_q=None): 2597 # Original Python callable from user. 2598 self.py_callable = py_callable 2599 # Process pool created for current iterator. 2600 self.pool = pool 2601 # Python callable index for subprocess _GLOBAL_PYFUNC_LIST 2602 self.idx = idx 2603 2604 if pool is not None: 2605 self.queuemap = {} 2606 self.arg_q = arg_q 2607 self.res_q = res_q 2608 self.next_queue = 0 2609 2610 def __call__(self, *args): 2611 if self._pool_is_running() and check_iterator_cleanup() is False: 2612 # arg_q will have 0 size if we are not using shared memory 2613 # if using multi-processing shared queue instead of multiprocess arg passing 2614 if self.arg_q != []: 2615 tid = threading.get_ident() 2616 # Need to register each thread to use a different queue to send data to pool 2617 if not tid in self.queuemap: 2618 qid = self.next_queue 2619 self.next_queue = self.next_queue + 1 2620 self.queuemap[tid] = qid 2621 else: 2622 qid = self.queuemap[tid] 2623 self.arg_q[qid].put(args) 2624 2625 # This call will send the tensors along with Python callable index to the process pool. 2626 # Block, yield GIL. Current thread will reacquire GIL once result is returned. 2627 if self._pool_is_running() and check_iterator_cleanup() is False: 2628 result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, qid, []]) 2629 else: 2630 return self.py_callable(*args) 2631 else: 2632 result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, -1, *args]) 2633 2634 # todo this check might be wrong 2635 while check_iterator_cleanup() is False: 2636 try: 2637 if self.arg_q != []: 2638 r = result.get(30) 2639 if isinstance(r, ExceptionHandler): 2640 r.reraise() 2641 if r[0] != qid: 2642 raise Exception("In PyCallable, got results from wrong thread") 2643 r = self.res_q[qid].get() 2644 return r 2645 r = result.get(30) 2646 if isinstance(r, ExceptionHandler): 2647 r.reraise() 2648 return r 2649 except multiprocessing.TimeoutError: 2650 continue 2651 except KeyboardInterrupt: 2652 _set_iterator_cleanup() 2653 self.pool.close() 2654 self.pool.join() 2655 raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt.") 2656 return (None,) 2657 # Invoke original Python callable in master process in case the pool is gone. 2658 return self.py_callable(*args) 2659 2660 def to_json(self): 2661 return self.py_callable.to_json() 2662 2663 def _pool_is_running(self): 2664 # note here: the RUN state of python3.7 and python3.8 is different: 2665 # python3.7: RUN = 0 2666 # python3.8: RUN = "RUN" 2667 # so we use self.pool._state == RUN instead and we can't use _state == 0 any more. 2668 if self.pool is not None and self.pool._state == RUN: # pylint: disable=W0212 2669 return True 2670 return False 2671 2672 2673def _mp_pool_exit_preprocess(): 2674 if check_iterator_cleanup() is False: 2675 # Set the iterator_cleanup flag to True before exiting, and wait 3s for all apply_async 2676 # applied to the multiprocessing task to prevent multiprocessing from hang when exiting 2677 _set_iterator_cleanup() 2678 time.sleep(3) 2679 2680 2681class _ExceptHookHandler: 2682 def __init__(self): 2683 sys.excepthook = self.__handler_exception 2684 2685 def __handler_exception(self, ex_type, value, tb): 2686 logger.error("Uncaught exception: ", exc_info=(ex_type, value, tb)) 2687 _mp_pool_exit_preprocess() 2688 2689 2690class MapDataset(Dataset): 2691 """ 2692 The result of applying the Map operator to the input Dataset. 2693 2694 Args: 2695 input_dataset (Dataset): Input Dataset to be mapped. 2696 operations (TensorOp): A function mapping a nested structure of tensors 2697 to another nested structure of tensor (default=None). 2698 input_columns (Union[str, list[str]]): List of names of the input columns 2699 (default=None, the operations will be applied on the first columns in the dataset). 2700 The size of the list should match the number of inputs of the first operator. 2701 output_columns (Union[str, list[str]], optional): List of names of the output columns. 2702 The size of the list should match the number of outputs of the last operator 2703 (default=None, output columns will be the input columns, i.e., the columns will 2704 be replaced). 2705 column_order (list[str], optional): Specifies the list of all the columns you need in the whole 2706 dataset. The parameter is required when len(input_column) != len(output_column). Caution: the list here 2707 is not just the columns specified in parameter input_columns and output_columns. 2708 num_parallel_workers (int, optional): Number of workers to process the dataset 2709 in parallel (default=None). 2710 python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This 2711 option could be beneficial if the Python operation is computational heavy (default=False). 2712 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 2713 (default=None, which means no cache is used). 2714 callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None) 2715 max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy 2716 data between processes. This is only used if python_multiprocessing is set to True (default 16 MB). 2717 2718 Raises: 2719 ValueError: If len(input_columns) != len(output_columns) and column_order is not specified. 2720 """ 2721 2722 def __init__(self, input_dataset, operations=None, input_columns=None, output_columns=None, column_order=None, 2723 num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16): 2724 super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers, cache=cache) 2725 self.operations = to_list(operations) 2726 self.operations = py_transforms.Compose.reduce(self.operations) 2727 self.input_columns = to_list(input_columns) 2728 self.output_columns = to_list(output_columns) 2729 self.column_order = replace_none(column_order, []) 2730 2731 # If output_columns were not provided then use input_columns 2732 self.output_columns = self.input_columns if not self.output_columns else self.output_columns 2733 2734 if self.input_columns and self.output_columns \ 2735 and len(self.input_columns) != len(self.output_columns) \ 2736 and not self.column_order: 2737 raise ValueError("When length of input_columns and output_columns are not equal," 2738 " column_order must be specified.") 2739 2740 self.python_multiprocessing = python_multiprocessing 2741 self.process_pool = None 2742 self.hook = None 2743 self.pids = [] 2744 self.eot = None 2745 self.watch_dog = None 2746 2747 self.callbacks = to_list(callbacks) 2748 self.max_rowsize = max_rowsize 2749 2750 def parse(self, children=None): 2751 operations = [] 2752 for op in self.operations: 2753 if op and getattr(op, 'parse', None): 2754 operations.append(op.parse()) 2755 else: 2756 operations.append(op) 2757 2758 callbacks = [cb.create_runtime_obj() for cb in self.callbacks] 2759 return cde.MapNode(children[0], operations, self.input_columns, self.output_columns, self.column_order, 2760 callbacks) 2761 2762 def __deepcopy__(self, memodict): 2763 return self.__safe_deepcopy__(memodict, exclude=("operations", "callbacks", "__transfer_dataset__")) 2764 2765 # Iterator bootstrap will be called on iterator construction. 2766 # A deep copy of Dataset object is created prior of iterator_bootstrap. 2767 # This method will create per iterator process pool and bind pyfunc execution to the pool. 2768 def iterator_bootstrap(self): 2769 """ 2770 Per iterator bootstrap callback. 2771 """ 2772 2773 if self.python_multiprocessing: 2774 iter_specific_operations = [] 2775 callable_list = [] 2776 arg_q_list = [] 2777 res_q_list = [] 2778 2779 # If user didn't specify num_parallel_workers, set it to default 2780 if self.num_parallel_workers is not None: 2781 num_parallel = self.num_parallel_workers 2782 else: 2783 num_parallel = get_num_parallel_workers() 2784 2785 if get_enable_shared_mem(): 2786 _check_shm_usage(num_parallel, 1, self.max_rowsize, 2) 2787 for _ in range(num_parallel): 2788 arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize)) 2789 res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize)) 2790 2791 # Pass #1, look for Python callables and build list 2792 for op in self.operations: 2793 # our c transforms is now callable and should not be run in Python multithreading 2794 if callable(op) and str(op).find("c_transform") < 0: 2795 callable_list.append(op) 2796 2797 if callable_list: 2798 # Construct pool with the callable list 2799 # The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses 2800 self.process_pool = multiprocessing.Pool(processes=num_parallel, 2801 initializer=_pyfunc_worker_init, 2802 initargs=(callable_list, arg_q_list, res_q_list)) 2803 2804 # Pass #2 2805 idx = 0 2806 global _OP_NAME, _OP_PROCESS, _LOCK 2807 op_id = _OP_NAME[str(self)] 2808 # obtain process id from multiprocessing.pool 2809 process_id = {op_id: [self.num_parallel_workers, set()]} 2810 for pool in self.process_pool._pool: # pylint: disable=W0212 2811 process_id[op_id][1].add(pool.pid) 2812 self.pids.append(pool.pid) 2813 with _LOCK: 2814 _OP_PROCESS.update(process_id) 2815 for op in self.operations: 2816 # our c transforms is now callable and should not be run in Python multithreading 2817 if callable(op) and str(op).find("c_transform") < 0: 2818 # Wrap Python callable into _PythonCallable 2819 iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool, 2820 arg_q_list, res_q_list)) 2821 idx += 1 2822 else: 2823 # CPP ops remain the same 2824 iter_specific_operations.append(op) 2825 self.operations = iter_specific_operations 2826 self.hook = _ExceptHookHandler() 2827 atexit.register(_mp_pool_exit_preprocess) 2828 # If Python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown. 2829 if sys.version_info >= (3, 8): 2830 atexit.register(self.process_pool.close) 2831 if platform.system().lower() != 'windows': 2832 self.eot = threading.Event() 2833 self.watch_dog = threading.Thread(target=_watch_dog, args=(self.eot, self.pids)) 2834 self.watch_dog.daemon = True 2835 self.watch_dog.start() 2836 2837 def _abort_watchdog(self): 2838 if not self.eot.is_set(): 2839 self.eot.set() 2840 2841 def __del__(self): 2842 if hasattr(self, 'process_pool') and self.process_pool is not None: 2843 self.process_pool.close() 2844 self.process_pool.join() 2845 if hasattr(self, 'watch_dog') and self.watch_dog is not None and hasattr(self, 'eot') and self.eot is not None: 2846 self._abort_watchdog() 2847 2848 2849class FilterDataset(Dataset): 2850 """ 2851 The result of applying filter predicate to the input Dataset. 2852 2853 Args: 2854 input_dataset (Dataset): Input Dataset to be mapped. 2855 predicate (callable): Python callable which returns a boolean value. If False then filter the element. 2856 input_columns (Union[str, list[str]], optional): List of names of the input columns 2857 (default=None, the predicate will be applied to all columns in the dataset). 2858 num_parallel_workers (int, optional): Number of workers to process the dataset 2859 in parallel (default=None). 2860 """ 2861 2862 def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None): 2863 super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers) 2864 self.predicate = lambda *args: bool(predicate(*args)) 2865 self.input_columns = to_list(input_columns) 2866 2867 def parse(self, children=None): 2868 return cde.FilterNode(children[0], self.predicate, self.input_columns) 2869 2870 2871class RepeatDataset(Dataset): 2872 """ 2873 The result of applying Repeat operator to the input Dataset. 2874 2875 Args: 2876 input_dataset (Dataset): Input Dataset to be repeated. 2877 count (int): Number of times the dataset will be repeated (default=-1, repeat indefinitely). 2878 """ 2879 2880 def __init__(self, input_dataset, count): 2881 super().__init__(children=input_dataset) 2882 self.count = replace_none(count, -1) 2883 2884 def parse(self, children=None): 2885 return cde.RepeatNode(children[0], self.count) 2886 2887 2888class SkipDataset(Dataset): 2889 """ 2890 The result of applying Skip operator to the input Dataset. 2891 2892 Args: 2893 input_dataset (Dataset): Input dataset to have elements skipped. 2894 count (int): Number of elements to be skipped in the dataset. 2895 """ 2896 2897 def __init__(self, input_dataset, count): 2898 super().__init__(input_dataset) 2899 self.count = count 2900 2901 def parse(self, children=None): 2902 return cde.SkipNode(children[0], self.count) 2903 2904 2905class TakeDataset(Dataset): 2906 """ 2907 The result of applying Take operator to the input Dataset. 2908 2909 Args: 2910 input_dataset (Dataset): Input Dataset to have elements taken from. 2911 count (int): Number of elements to be taken from the dataset. 2912 """ 2913 2914 def __init__(self, input_dataset, count): 2915 super().__init__(children=input_dataset) 2916 self.count = count 2917 2918 def parse(self, children=None): 2919 return cde.TakeNode(children[0], self.count) 2920 2921 2922class ZipDataset(Dataset): 2923 """ 2924 The result of applying Zip operator to the input Dataset. 2925 2926 Args: 2927 datasets (tuple): A tuple of datasets to be zipped together. 2928 2929 Raises: 2930 TypeError: If dataset is not an instance of Dataset. 2931 """ 2932 2933 def __init__(self, datasets): 2934 super().__init__(children=datasets) 2935 2936 def parse(self, children=None): 2937 return cde.ZipNode(children) 2938 2939 def is_sync(self): 2940 return any([c.is_sync() for c in self.children]) 2941 2942 2943class ConcatDataset(Dataset): 2944 """ 2945 The result of applying concat dataset operator to the input Dataset. 2946 2947 Args: 2948 datasets (list): A list of datasets to be concatenated together. 2949 2950 Raises: 2951 TypeError: If dataset is not an instance of Dataset. 2952 ValueError: If there is no samples in the one of the datasets. 2953 """ 2954 2955 def __init__(self, datasets): 2956 super().__init__(children=datasets) 2957 for dataset in datasets: 2958 if not isinstance(dataset, Dataset): 2959 raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset)) 2960 self.datasets = datasets 2961 self._sampler = samplers.SequentialSampler(num_samples=None) 2962 2963 self.children_sizes_ = [c.get_dataset_size() for c in self.children] 2964 child_index = 0 2965 for item in self.children_sizes_: 2966 if item == 0: 2967 raise ValueError("There are no samples in the dataset number %d. Please make sure there are " 2968 "valid samples in the dataset." % child_index) 2969 child_index += 1 2970 2971 # _children_flag_and_nums: A list of pair<int ,int>.The first element of pair is flag that characterizes 2972 # whether the data set is mappable. The second element of pair is length of the dataset 2973 self._children_flag_and_nums = [] 2974 2975 # _children_start_end_index_: A list of pair<int ,int>.The elements of pair are used to characterize 2976 # the valid position of the dataset corresponding to the subscript when sampling 2977 self._children_start_end_index_ = [] 2978 for index, child in enumerate(self.children): 2979 tem_list = [-1, -1] 2980 self._children_start_end_index_.append(tem_list) 2981 dataset_len = self.children_sizes_[index] 2982 if isinstance(child, GeneratorDataset) and not hasattr(child.source, "__getitem__"): 2983 dataset_len = 0 2984 self.children_sizes_[index] = 0 2985 2986 if isinstance(child, MappableDataset): 2987 self._children_flag_and_nums.append((0, dataset_len)) 2988 else: 2989 self._children_flag_and_nums.append((1, dataset_len)) 2990 2991 def parse(self, children=None): 2992 return cde.ConcatNode(children, self._sampler, self._children_flag_and_nums, self._children_start_end_index_) 2993 2994 def use_sampler(self, sampler): 2995 """ 2996 Set the distributedSampler to concat dataset 2997 2998 Args: 2999 sampler (Sampler): The sampler to use for the current dataset. 3000 Currently supported: DistributedSampler. 3001 3002 Raises: 3003 TypeError: If the sampler is not an instance of DistributedSampler 3004 ValueError: If the parameter shuffle of sampler is True 3005 ValueError: If the parameter NumSamples of sampler is not None. 3006 ValueError: If num_shards <=0. 3007 """ 3008 if not isinstance(sampler, samplers.DistributedSampler): 3009 raise TypeError("The parameter %s of concat must be DistributedSampler!" % sampler) 3010 3011 if sampler.is_shuffled(): 3012 raise ValueError("The parameter shuffle of DistributedSampler must be False!") 3013 3014 if sampler.num_shards <= 0: 3015 raise ValueError("The parameter num_shards of DistributedSampler must be positive int!") 3016 3017 if sampler.get_num_samples() is not None: 3018 raise ValueError("The parameter num_samples of DistributedSampler is not support to be set!") 3019 3020 self.dataset_size = None 3021 3022 self._sampler = sampler 3023 cumulative_samples_nums = 0 3024 for index, child in enumerate(self.children): 3025 if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None: 3026 raise ValueError("The parameter NumSamples of %s is not support to be set!" % child) 3027 3028 if isinstance(child, BatchDataset): 3029 raise TypeError("The parameter %s of concat must not be BatchDataset!" % child) 3030 3031 # if child is mappable and the length is greater than 0 3032 if not self._children_flag_and_nums[index][0] and self._children_flag_and_nums[index][1]: 3033 3034 tem_value = cumulative_samples_nums + self._children_flag_and_nums[index][1] 3035 3036 if not self._children_flag_and_nums[index][1] >= sampler.num_shards: 3037 if tem_value < sampler.num_shards: 3038 self._children_start_end_index_[index][0] = cumulative_samples_nums 3039 self._children_start_end_index_[index][1] = tem_value 3040 else: 3041 self._children_start_end_index_[index][0] = cumulative_samples_nums 3042 self._children_start_end_index_[index][1] = tem_value % sampler.num_shards 3043 3044 tem_sampler = copy.deepcopy(sampler) 3045 tem_sampler.set_offset(cumulative_samples_nums) 3046 child.use_sampler(tem_sampler) 3047 3048 cumulative_samples_nums += self.children_sizes_[index] 3049 cumulative_samples_nums %= sampler.num_shards 3050 3051 3052class RenameDataset(Dataset): 3053 """ 3054 The result of applying Rename operator to the input Dataset. 3055 3056 Args: 3057 input_dataset (Dataset): Input Dataset to be Renamed. 3058 input_columns (Union[str, list[str]]): List of names of the input columns. 3059 output_columns (Union[str, list[str]]): List of names of the output columns. 3060 """ 3061 3062 def __init__(self, input_dataset, input_columns, output_columns): 3063 super().__init__(children=input_dataset) 3064 self.input_column_names = to_list(input_columns) 3065 self.output_column_names = to_list(output_columns) 3066 3067 def parse(self, children=None): 3068 return cde.RenameNode(children[0], self.input_column_names, self.output_column_names) 3069 3070 3071def to_list(items): 3072 if items is None: 3073 return [] 3074 if isinstance(items, tuple): 3075 return list(items) 3076 if not isinstance(items, list): 3077 return [items] 3078 return items 3079 3080 3081class ProjectDataset(Dataset): 3082 """ 3083 The result of applying Project operator to the input Dataset. 3084 3085 Args: 3086 input_dataset (Dataset): Input Dataset to be Projected. 3087 columns (Union[str, list[str]]): List of names of the columns to project. 3088 """ 3089 3090 def __init__(self, input_dataset, columns): 3091 super().__init__(children=input_dataset) 3092 self.columns = to_list(columns) 3093 3094 def parse(self, children=None): 3095 return cde.ProjectNode(children[0], self.columns) 3096 3097 3098class _ToDevice: 3099 """ 3100 Internal class to handle sending data to device. 3101 """ 3102 3103 def __init__(self, dataset, num_epochs): 3104 ir_tree, self.api_tree = dataset.create_ir_tree() 3105 3106 self._runtime_context = cde.PythonRuntimeContext() 3107 self._runtime_context.Init() 3108 self._to_device = cde.ToDevice(num_epochs) 3109 self._to_device.Init(ir_tree) 3110 self._runtime_context.AssignConsumer(self._to_device) 3111 3112 ITERATORS_LIST.append(weakref.ref(self)) 3113 _unset_iterator_cleanup() 3114 3115 def send(self): 3116 self._to_device.Send() 3117 3118 def stop_send(self): 3119 """ 3120 send stop send signal to pipeline, it is used when end of sequence is sent at the epoch end. 3121 """ 3122 self._to_device.StopSend() 3123 3124 def continue_send(self): 3125 """ 3126 send continue send signal to pipeline, it is used when end of sequence is sent at the epoch end. 3127 """ 3128 self._to_device.ContinueSend() 3129 3130 def get_data_info(self): 3131 """ 3132 Get type and shape of current batch. 3133 """ 3134 return self._to_device.GetDataInfo() 3135 3136 def release(self): 3137 """ 3138 Manually terminate Device Queue instead of relying on out of scope destruction. 3139 """ 3140 if hasattr(self, '_runtime_context') and self._runtime_context: 3141 if hasattr(self, '_to_device') and self._to_device: 3142 self._runtime_context.Terminate() 3143 del self._to_device 3144 del self._runtime_context 3145 3146 def __deepcopy__(self, memodict): 3147 return self 3148 3149 3150class TransferDataset(Dataset): 3151 """ 3152 The result of applying TDT operator to the input Dataset. 3153 3154 Args: 3155 input_dataset (Dataset): Input Dataset to be transferred. 3156 send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True). 3157 create_data_info_queue (bool, optional): Whether to create queue which stores 3158 types and shapes of data or not (default=False). 3159 3160 Raises: 3161 TypeError: If device_type is empty. 3162 ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'. 3163 RuntimeError: If dataset is unknown. 3164 """ 3165 3166 def __init__(self, input_dataset, send_epoch_end=True, create_data_info_queue=False): 3167 super().__init__(children=input_dataset) 3168 self.queue_name = str(uuid.uuid1()) 3169 self.device_type = context.get_context("device_target") if context else "CPU" 3170 self.device_id = context.get_context("device_id") if context else 0 3171 3172 self._send_epoch_end = replace_none(send_epoch_end, True) 3173 self._create_data_info_queue = create_data_info_queue 3174 self._to_device = None 3175 3176 def parse(self, children=None): 3177 total_batch = 0 3178 if hasattr(self.children[0], "__total_batch__"): 3179 total_batch = self.children[0].__total_batch__ 3180 return cde.TransferNode(children[0], self.queue_name, self.device_type, self.device_id, self._send_epoch_end, 3181 total_batch, self._create_data_info_queue) 3182 3183 def create_dict_iterator(self, num_epochs=-1, output_numpy=False): 3184 raise RuntimeError("TransferDataset is not iterable.") 3185 3186 def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True): 3187 raise RuntimeError("TransferDataset is not iterable.") 3188 3189 def __iter__(self): 3190 raise RuntimeError("TransferDataset is not iterable.") 3191 3192 def output_shapes(self): 3193 raise RuntimeError("TransferDataset does not support obtaining output_shapes.") 3194 3195 def output_types(self): 3196 raise RuntimeError("TransferDataset does not support obtaining output_types.") 3197 3198 @check_to_device_send 3199 def send(self, num_epochs=-1): 3200 """ 3201 Send to device 3202 """ 3203 if Dataset._noop_mode(): 3204 return 3205 if self._to_device is not None: 3206 del self._to_device 3207 self._to_device = _ToDevice(self, num_epochs) 3208 self._to_device.send() 3209 3210 def stop_send(self): 3211 if self._to_device is not None: 3212 self._to_device.stop_send() 3213 3214 def continue_send(self): 3215 if self._to_device is not None: 3216 self._to_device.continue_send() 3217 3218 def get_data_info(self): 3219 """ 3220 Get type and shape of current batch 3221 """ 3222 if self._to_device is not None: 3223 return self._to_device.get_data_info() 3224 raise RuntimeError("Calling get_data_info with bad state.") 3225 3226 def release(self): 3227 """ 3228 Manually terminate Device Queue instead of relying on out of scope destruction. 3229 """ 3230 if self._to_device is not None: 3231 self._to_device.release() 3232 3233 3234class RangeDataset(MappableDataset): 3235 """ 3236 A source dataset that reads and parses datasets stored on disk in a range. 3237 3238 Args: 3239 start (int): Starting index. 3240 stop (int): Ending index. 3241 step (int): Step size in the range specified by start and stop. 3242 """ 3243 3244 def __init__(self, start, stop, step): 3245 super().__init__() 3246 self.start = start 3247 self.stop = stop 3248 self.step = step 3249 3250 def parse(self, children=None): 3251 raise NotImplementedError("Dataset has to implement parse method.") 3252 3253 def is_shuffled(self): 3254 return False 3255 3256 def is_sharded(self): 3257 return False 3258 3259 def get_dataset_size(self): 3260 if self.dataset_size is None: 3261 self.dataset_size = math.ceil((self.stop - self.start) / self.step) 3262 return self.dataset_size 3263 3264 3265class ImageFolderDataset(MappableDataset): 3266 """ 3267 A source dataset that reads images from a tree of directories. 3268 All images within one folder have the same label. 3269 3270 The generated dataset has two columns: :py:obj:`[image, label]`. 3271 The tensor of column :py:obj:`image` is of the uint8 type. 3272 The tensor of column :py:obj:`label` is of a scalar of uint32 type. 3273 3274 Args: 3275 dataset_dir (str): Path to the root directory that contains the dataset. 3276 num_samples (int, optional): The number of images to be included in the dataset 3277 (default=None, all images). 3278 num_parallel_workers (int, optional): Number of workers to read the data 3279 (default=None, set in the config). 3280 shuffle (bool, optional): Whether or not to perform shuffle on the dataset 3281 (default=None, expected order behavior shown in the table). 3282 sampler (Sampler, optional): Object used to choose samples from the 3283 dataset (default=None, expected order behavior shown in the table). 3284 extensions (list[str], optional): List of file extensions to be 3285 included in the dataset (default=None). 3286 class_indexing (dict, optional): A str-to-int mapping from folder name to index 3287 (default=None, the folder names will be sorted 3288 alphabetically and each class will be given a 3289 unique index starting from 0). 3290 decode (bool, optional): Decode the images after reading (default=False). 3291 num_shards (int, optional): Number of shards that the dataset will be divided 3292 into (default=None). When this argument is specified, `num_samples` reflects 3293 the maximum sample number of per shard. 3294 shard_id (int, optional): The shard ID within num_shards (default=None). This 3295 argument can only be specified when num_shards is also specified. 3296 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 3297 (default=None, which means no cache is used). 3298 3299 Raises: 3300 RuntimeError: If dataset_dir does not contain data files. 3301 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 3302 RuntimeError: If sampler and shuffle are specified at the same time. 3303 RuntimeError: If sampler and sharding are specified at the same time. 3304 RuntimeError: If num_shards is specified but shard_id is None. 3305 RuntimeError: If shard_id is specified but num_shards is None. 3306 RuntimeError: If class_indexing is not a dictionary. 3307 ValueError: If shard_id is invalid (< 0 or >= num_shards). 3308 3309 Note: 3310 - The shape of the image column is [image_size] if decode flag is False, or [H,W,C] otherwise. 3311 - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 3312 The table below shows what input arguments are allowed and their expected behavior. 3313 3314 .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` 3315 :widths: 25 25 50 3316 :header-rows: 1 3317 3318 * - Parameter `sampler` 3319 - Parameter `shuffle` 3320 - Expected Order Behavior 3321 * - None 3322 - None 3323 - random order 3324 * - None 3325 - True 3326 - random order 3327 * - None 3328 - False 3329 - sequential order 3330 * - Sampler object 3331 - None 3332 - order defined by sampler 3333 * - Sampler object 3334 - True 3335 - not allowed 3336 * - Sampler object 3337 - False 3338 - not allowed 3339 3340 Examples: 3341 >>> image_folder_dataset_dir = "/path/to/image_folder_dataset_directory" 3342 >>> 3343 >>> # 1) Read all samples (image files) in image_folder_dataset_dir with 8 threads 3344 >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir, 3345 ... num_parallel_workers=8) 3346 >>> 3347 >>> # 2) Read all samples (image files) from folder cat and folder dog with label 0 and 1 3348 >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir, 3349 ... class_indexing={"cat":0, "dog":1}) 3350 >>> 3351 >>> # 3) Read all samples (image files) in image_folder_dataset_dir with extensions .JPEG and .png (case sensitive) 3352 >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir, 3353 ... extensions=[".JPEG", ".png"]) 3354 3355 About ImageFolderDataset: 3356 3357 You can construct the following directory structure from your dataset files and read by MindSpore's API. 3358 3359 .. code-block:: 3360 3361 . 3362 └── image_folder_dataset_directory 3363 ├── class1 3364 │ ├── 000000000001.jpg 3365 │ ├── 000000000002.jpg 3366 │ ├── ... 3367 ├── class2 3368 │ ├── 000000000001.jpg 3369 │ ├── 000000000002.jpg 3370 │ ├── ... 3371 ├── class3 3372 │ ├── 000000000001.jpg 3373 │ ├── 000000000002.jpg 3374 │ ├── ... 3375 ├── classN 3376 ├── ... 3377 """ 3378 3379 @check_imagefolderdataset 3380 def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, 3381 extensions=None, class_indexing=None, decode=False, num_shards=None, shard_id=None, cache=None): 3382 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 3383 shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) 3384 3385 self.dataset_dir = dataset_dir 3386 self.extensions = replace_none(extensions, []) 3387 self.class_indexing = replace_none(class_indexing, {}) 3388 self.decode = replace_none(decode, False) 3389 3390 def parse(self, children=None): 3391 return cde.ImageFolderNode(self.dataset_dir, self.decode, self.sampler, self.extensions, self.class_indexing) 3392 3393 3394class MnistDataset(MappableDataset): 3395 """ 3396 A source dataset for reading and parsing the MNIST dataset. 3397 3398 The generated dataset has two columns :py:obj:`[image, label]`. 3399 The tensor of column :py:obj:`image` is of the uint8 type. 3400 The tensor of column :py:obj:`label` is a scalar of the uint32 type. 3401 3402 Args: 3403 dataset_dir (str): Path to the root directory that contains the dataset. 3404 usage (str, optional): Usage of this dataset, can be `train`, `test` or `all` . `train` will read from 60,000 3405 train samples, `test` will read from 10,000 test samples, `all` will read from all 70,000 samples. 3406 (default=None, will read all samples) 3407 num_samples (int, optional): The number of images to be included in the dataset 3408 (default=None, will read all images). 3409 num_parallel_workers (int, optional): Number of workers to read the data 3410 (default=None, will use value set in the config). 3411 shuffle (bool, optional): Whether or not to perform shuffle on the dataset 3412 (default=None, expected order behavior shown in the table). 3413 sampler (Sampler, optional): Object used to choose samples from the 3414 dataset (default=None, expected order behavior shown in the table). 3415 num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). 3416 When this argument is specified, `num_samples` reflects the maximum sample number of per shard. 3417 shard_id (int, optional): The shard ID within `num_shards` (default=None). This 3418 argument can only be specified when `num_shards` is also specified. 3419 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 3420 (default=None, which means no cache is used). 3421 3422 Raises: 3423 RuntimeError: If dataset_dir does not contain data files. 3424 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 3425 RuntimeError: If sampler and shuffle are specified at the same time. 3426 RuntimeError: If sampler and sharding are specified at the same time. 3427 RuntimeError: If num_shards is specified but shard_id is None. 3428 RuntimeError: If shard_id is specified but num_shards is None. 3429 ValueError: If shard_id is invalid (< 0 or >= num_shards). 3430 3431 Note: 3432 - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 3433 The table below shows what input arguments are allowed and their expected behavior. 3434 3435 .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` 3436 :widths: 25 25 50 3437 :header-rows: 1 3438 3439 * - Parameter `sampler` 3440 - Parameter `shuffle` 3441 - Expected Order Behavior 3442 * - None 3443 - None 3444 - random order 3445 * - None 3446 - True 3447 - random order 3448 * - None 3449 - False 3450 - sequential order 3451 * - Sampler object 3452 - None 3453 - order defined by sampler 3454 * - Sampler object 3455 - True 3456 - not allowed 3457 * - Sampler object 3458 - False 3459 - not allowed 3460 3461 Examples: 3462 >>> mnist_dataset_dir = "/path/to/mnist_dataset_directory" 3463 >>> 3464 >>> # Read 3 samples from MNIST dataset 3465 >>> dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir, num_samples=3) 3466 >>> 3467 >>> # Note: In mnist_dataset dataset, each dictionary has keys "image" and "label" 3468 3469 About MNIST dataset: 3470 3471 The MNIST database of handwritten digits has a training set of 60,000 examples, 3472 and a test set of 10,000 examples. It is a subset of a larger set available from 3473 NIST. The digits have been size-normalized and centered in a fixed-size image. 3474 3475 Here is the original MNIST dataset structure. 3476 You can unzip the dataset files into this directory structure and read by MindSpore's API. 3477 3478 .. code-block:: 3479 3480 . 3481 └── mnist_dataset_dir 3482 ├── t10k-images-idx3-ubyte 3483 ├── t10k-labels-idx1-ubyte 3484 ├── train-images-idx3-ubyte 3485 └── train-labels-idx1-ubyte 3486 3487 Citation: 3488 3489 .. code-block:: 3490 3491 @article{lecun2010mnist, 3492 title = {MNIST handwritten digit database}, 3493 author = {LeCun, Yann and Cortes, Corinna and Burges, CJ}, 3494 journal = {ATT Labs [Online]}, 3495 volume = {2}, 3496 year = {2010}, 3497 howpublished = {http://yann.lecun.com/exdb/mnist} 3498 } 3499 """ 3500 3501 @check_mnist_cifar_dataset 3502 def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, 3503 sampler=None, num_shards=None, shard_id=None, cache=None): 3504 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 3505 shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) 3506 3507 self.dataset_dir = dataset_dir 3508 self.usage = replace_none(usage, "all") 3509 3510 def parse(self, children=None): 3511 return cde.MnistNode(self.dataset_dir, self.usage, self.sampler) 3512 3513 3514class MindDataset(MappableDataset): 3515 """ 3516 A source dataset for reading and parsing MindRecord dataset. 3517 3518 The columns of generated dataset depend on the source MindRecord files. 3519 3520 Args: 3521 dataset_file (Union[str, list[str]]): If dataset_file is a str, it represents for 3522 a file name of one component of a mindrecord source, other files with identical source 3523 in the same path will be found and loaded automatically. If dataset_file is a list, 3524 it represents for a list of dataset files to be read directly. 3525 columns_list (list[str], optional): List of columns to be read (default=None). 3526 num_parallel_workers (int, optional): The number of readers (default=None). 3527 shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch 3528 (default=None, performs global shuffle). 3529 If shuffle is False, no shuffling will be performed; 3530 If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL 3531 Otherwise, there are three levels of shuffling: 3532 3533 - Shuffle.GLOBAL: Global shuffle of all rows of data in dataset. 3534 3535 - Shuffle.FILES: Shuffle the file sequence but keep the order of data within each file. 3536 3537 - Shuffle.INFILE: Keep the file sequence the same but shuffle the data within each file. 3538 3539 num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). 3540 When this argument is specified, 'num_samples' reflects the maximum sample number of per shard. 3541 shard_id (int, optional): The shard ID within num_shards (default=None). This 3542 argument can only be specified when num_shards is also specified. 3543 sampler (Sampler, optional): Object used to choose samples from the 3544 dataset (default=None, sampler is exclusive 3545 with shuffle and block_reader). Support list: SubsetRandomSampler, 3546 PkSampler, RandomSampler, SequentialSampler, DistributedSampler. 3547 padded_sample (dict, optional): Samples will be appended to dataset, where 3548 keys are the same as column_list. 3549 num_padded (int, optional): Number of padding samples. Dataset size 3550 plus num_padded should be divisible by num_shards. 3551 num_samples (int, optional): The number of samples to be included in the dataset 3552 (default=None, all samples). 3553 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 3554 (default=None, which means no cache is used). 3555 3556 Raises: 3557 RuntimeError: If dataset_files are not valid or do not exist. 3558 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 3559 RuntimeError: If num_shards is specified but shard_id is None. 3560 RuntimeError: If shard_id is specified but num_shards is None. 3561 ValueError: If shard_id is invalid (< 0 or >= num_shards). 3562 3563 Note: 3564 - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 3565 The table below shows what input arguments are allowed and their expected behavior. 3566 3567 .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` 3568 :widths: 25 25 50 3569 :header-rows: 1 3570 3571 * - Parameter `sampler` 3572 - Parameter `shuffle` 3573 - Expected Order Behavior 3574 * - None 3575 - None 3576 - random order 3577 * - None 3578 - True 3579 - random order 3580 * - None 3581 - False 3582 - sequential order 3583 * - Sampler object 3584 - None 3585 - order defined by sampler 3586 * - Sampler object 3587 - True 3588 - not allowed 3589 * - Sampler object 3590 - False 3591 - not allowed 3592 3593 Examples: 3594 >>> mind_dataset_dir = ["/path/to/mind_dataset_file"] # contains 1 or multiple MindRecord files 3595 >>> dataset = ds.MindDataset(dataset_file=mind_dataset_dir) 3596 """ 3597 3598 def parse(self, children=None): 3599 return cde.MindDataNode(self.dataset_file, self.columns_list, self.sampler, self.new_padded_sample, 3600 self.num_padded, shuffle_to_shuffle_mode(self.shuffle_option)) 3601 3602 @check_minddataset 3603 def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None, shuffle=None, num_shards=None, 3604 shard_id=None, sampler=None, padded_sample=None, num_padded=None, num_samples=None, cache=None): 3605 if shuffle is not None and not isinstance(shuffle, (bool, Shuffle)): 3606 raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like 'Shuffle.GLOBAL' or " 3607 "'Shuffle.FILES' or 'Shuffle.INFILE'.") 3608 self.shuffle_option = shuffle 3609 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 3610 shuffle=shuffle_to_bool(shuffle), num_shards=num_shards, shard_id=shard_id, cache=cache) 3611 if isinstance(dataset_file, list): 3612 self.load_dataset = False 3613 else: 3614 self.load_dataset = True 3615 self.dataset_file = dataset_file 3616 self.columns_list = replace_none(columns_list, []) 3617 3618 if shuffle is False: 3619 logger.warning("WARN: global shuffle is not used.") 3620 3621 if sampler is not None: 3622 if isinstance(sampler, ( 3623 samplers.SubsetRandomSampler, samplers.SubsetSampler, samplers.PKSampler, 3624 samplers.DistributedSampler, 3625 samplers.RandomSampler, samplers.SequentialSampler)) is False: 3626 raise ValueError("The sampler is not supported yet.") 3627 3628 self.padded_sample = padded_sample 3629 self.num_padded = replace_none(num_padded, 0) 3630 3631 self.new_padded_sample = {} 3632 if padded_sample: 3633 for k, v in padded_sample.items(): 3634 if isinstance(v, np.ndarray): 3635 self.new_padded_sample[k] = v.tobytes() 3636 else: 3637 self.new_padded_sample[k] = v 3638 3639 3640def _iter_fn(dataset, num_samples): 3641 """ 3642 Generator function wrapper for iterable dataset. 3643 """ 3644 if num_samples is not None and num_samples != 0: 3645 ds_iter = iter(dataset) 3646 for _ in range(num_samples): 3647 try: 3648 val = next(ds_iter) 3649 except StopIteration: 3650 return 3651 # convert output tensors to ndarrays 3652 yield _convert_row(val) 3653 else: 3654 for val in dataset: 3655 # convert output tensors to ndarrays 3656 yield _convert_row(val) 3657 3658 3659def _generator_fn(generator, num_samples): 3660 """ 3661 Generator function wrapper for generator function dataset. 3662 """ 3663 if num_samples is not None and num_samples != 0: 3664 gen_iter = generator() 3665 for _ in range(num_samples): 3666 try: 3667 val = next(gen_iter) 3668 except StopIteration: 3669 return 3670 yield val 3671 else: 3672 gen_iter = generator() 3673 for val in gen_iter: 3674 yield val 3675 3676 3677def _cpp_sampler_fn(sample_ids, dataset): 3678 """ 3679 Generator function wrapper for mappable dataset with cpp sampler. 3680 """ 3681 if not isinstance(sample_ids, np.ndarray): 3682 raise RuntimeError("Sample IDs are not in a numpy array.") 3683 if sample_ids.size == 0: 3684 raise RuntimeError("Sampler passed an empty sample IDs list.") 3685 3686 for i in sample_ids: 3687 val = dataset[i] 3688 # convert output tensors to ndarrays 3689 yield _convert_row(val) 3690 3691 3692def _cpp_sampler_fn_mp(sample_ids, sample_fn): 3693 """ 3694 Multiprocessing generator function wrapper for mappable dataset with cpp sampler. 3695 """ 3696 if not isinstance(sample_ids, np.ndarray): 3697 raise RuntimeError("Sample IDs are not in a numpy array.") 3698 if sample_ids.size == 0: 3699 raise RuntimeError("Sampler passed an empty sample IDs list.") 3700 3701 return sample_fn.process(sample_ids) 3702 3703 3704def _fill_worker_indices(workers, indices, idx): 3705 """ 3706 Worker index queue filler, fill worker index queue in round robin order. 3707 """ 3708 num_worker = len(workers) 3709 while idx < len(indices): 3710 try: 3711 workers[idx % num_worker].put(indices[idx]) 3712 idx += 1 3713 except queue.Full: 3714 break 3715 return idx 3716 3717 3718def _check_shm_usage(num_worker, queue_size, max_rowsize, num_queues=1): 3719 """ 3720 Check sufficient shared memory is available for shared memory queues 3721 when training in parallel mode. 3722 """ 3723 threshold_ratio = 0.8 3724 if platform.system() != "Windows" and _get_device_num() >= 1: 3725 shm_estimate_usage = _get_device_num() * num_worker * num_queues * \ 3726 (queue_size + 2) * max_rowsize * 1024 * 1024 3727 try: 3728 shm_available = psutil.disk_usage('/dev/shm').free 3729 if shm_estimate_usage >= threshold_ratio * shm_available: 3730 raise RuntimeError( 3731 "Insufficient shared memory available. Required: {}, Available: {}. " 3732 "The required memory can't exceed 80% of the available shared memory. " 3733 "Recommend to set_enable_shared_mem to False, reduce max_rowsize or reduce num_parallel_workers." 3734 .format(shm_estimate_usage, shm_available)) 3735 except FileNotFoundError: 3736 raise RuntimeError("Expected /dev/shm to exist.") 3737 3738 3739def _convert_row(row): 3740 """ 3741 Convert Op return value to numpy 3742 """ 3743 value = [] 3744 # convert each column in row into numpy array 3745 for x in row: 3746 if isinstance(x, bytes): # got image bytes from a file 3747 value.append(np.frombuffer(x, np.uint8)) 3748 elif isinstance(x, Tensor): # got mindspore.Tensor 3749 value.append(x.asnumpy()) 3750 else: 3751 value.append(np.array(x, copy=False)) 3752 return tuple(value) 3753 3754 3755class SamplerFn: 3756 """ 3757 Multiprocessing or multithread generator function wrapper master process. 3758 """ 3759 3760 def __init__(self, dataset, num_worker, multi_process, max_rowsize): 3761 self.workers = [] 3762 self.num_worker = num_worker 3763 self.multi_process = multi_process 3764 self.need_join = False 3765 self.ppid = os.getpid() 3766 self.pids = [] 3767 # Event for end of epoch 3768 if multi_process is True: 3769 try: 3770 self.eof = multiprocessing.Event() 3771 except Exception: 3772 raise RuntimeError("Init multiprocessing.Event() failed, This might be caused by insufficient shm," 3773 + " and the recommended shm size is at least 5 GB.") 3774 else: 3775 self.eof = threading.Event() 3776 # Create workers 3777 3778 # get default queue size and adjust queuesize per worker if there are large # workers 3779 queue_size = get_prefetch_size() 3780 queue_size = min(queue_size, queue_size * 4 // num_worker) 3781 queue_size = max(2, queue_size) 3782 3783 if multi_process and get_enable_shared_mem(): 3784 _check_shm_usage(num_worker, queue_size, max_rowsize) 3785 for _ in range(num_worker): 3786 if multi_process is True: 3787 try: 3788 worker = _GeneratorWorkerMp(dataset, self.eof, max_rowsize, queue_size) 3789 except Exception: 3790 raise RuntimeError("Init multiprocessing.Queue() failed, This might be caused by insufficient shm," 3791 + " and the recommended shm size is at least 5 GB.") 3792 worker.daemon = True 3793 # When multi processes fork a subprocess, the lock of the main process is copied to the subprocess, 3794 # which may cause deadlock. Therefore, the subprocess startup is performed in che initialization phase. 3795 # In this phase, the main process is not locked. 3796 worker.start() 3797 self.pids.append(worker.pid) 3798 self.need_join = True 3799 else: 3800 worker = _GeneratorWorkerMt(dataset, self.eof) 3801 worker.daemon = True 3802 self.workers.append(worker) 3803 if multi_process is True and platform.system().lower() != 'windows': 3804 self.eot = threading.Event() 3805 self.watch_dog = threading.Thread(target=_watch_dog, args=(self.eot, self.pids)) 3806 self.watch_dog.daemon = True 3807 self.watch_dog.start() 3808 3809 def process(self, indices): 3810 """ 3811 The main process, start the child process or child thread, and fill the index queue. 3812 Get the result and return. 3813 """ 3814 for w in self.workers: 3815 # Check whether the queue of the subprocess is empty. 3816 if not w.queue_empty(): 3817 raise Exception("The queue of the subprocess is not empty.") 3818 # Start all workers 3819 if not w.is_alive(): 3820 w.start() 3821 3822 # Fill initial index queues 3823 idx_cursor = 0 3824 idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) 3825 3826 # Fetch results 3827 for i in range(len(indices)): 3828 if self.eof.is_set(): 3829 self._stop_subprocess() 3830 return 3831 if self.multi_process is True and not psutil.pid_exists(self.workers[i % self.num_worker].pid): 3832 self._stop_subprocess() 3833 return 3834 # Fetch result and put index 3835 try: 3836 result = self.workers[i % self.num_worker].get() 3837 if isinstance(result, ExceptionHandler): 3838 result.reraise() 3839 except queue.Empty: 3840 self._stop_subprocess() 3841 raise Exception("Generator worker process timeout.") 3842 except KeyboardInterrupt: 3843 self._stop_subprocess() 3844 raise Exception("Generator worker receives KeyboardInterrupt.") 3845 if self.eof.is_set(): 3846 self._stop_subprocess() 3847 return 3848 if idx_cursor < len(indices): 3849 idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) 3850 yield _convert_row(result) 3851 3852 def _stop_subprocess(self): 3853 # Only the main process can call join 3854 if self.need_join is True and self.ppid == os.getpid(): 3855 self.eof.set() 3856 self.need_join = False 3857 for w in self.workers: 3858 if psutil.pid_exists(w.pid): 3859 w.join() 3860 self._abort_watchdog() 3861 3862 def _abort_watchdog(self): 3863 if hasattr(self, 'eot') and self.eot is not None and not self.eot.is_set(): 3864 self.eot.set() 3865 3866 def __del__(self): 3867 self._stop_subprocess() 3868 3869 3870def _subprocess_handle(eof, signum, frame): 3871 threading.Thread(target=eof.set()).start() 3872 3873 3874def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing): 3875 """ 3876 Multithread or multiprocess generator worker process loop. 3877 """ 3878 if is_multiprocessing: 3879 signal.signal(signal.SIGTERM, partial(_subprocess_handle, eof)) 3880 while True: 3881 # Fetch index, block 3882 try: 3883 idx = idx_queue.get(timeout=1) 3884 except KeyboardInterrupt: 3885 if is_multiprocessing: 3886 eof.set() 3887 idx_queue.cancel_join_thread() 3888 result_queue.cancel_join_thread() 3889 raise Exception("Generator worker receives KeyboardInterrupt.") 3890 except queue.Empty: 3891 if eof.is_set(): 3892 if is_multiprocessing: 3893 idx_queue.cancel_join_thread() 3894 result_queue.cancel_join_thread() 3895 return 3896 # If end-of-file (eof) is not set, continue to get data from idx_queue 3897 continue 3898 if idx is None: 3899 # When the queue is out of scope from master process, a None item can be fetched from the queue. 3900 # Upon receiving None, worker process should check if eof is set. 3901 if not eof.is_set(): 3902 raise Exception("") 3903 return 3904 if eof.is_set(): 3905 if is_multiprocessing: 3906 idx_queue.cancel_join_thread() 3907 result_queue.cancel_join_thread() 3908 return 3909 # Fetch data, any exception from __getitem__ will terminate worker and timeout master process 3910 try: 3911 result = dataset[idx] 3912 except Exception: 3913 result = ExceptionHandler(where="in GeneratorDataset worker process") 3914 # Send data, block 3915 while True: 3916 try: 3917 result_queue.put(result, timeout=5) 3918 except KeyboardInterrupt: 3919 if is_multiprocessing: 3920 eof.set() 3921 idx_queue.cancel_join_thread() 3922 result_queue.cancel_join_thread() 3923 raise Exception("Generator worker receives KeyboardInterrupt.") 3924 except queue.Full: 3925 if eof.is_set(): 3926 if is_multiprocessing: 3927 idx_queue.cancel_join_thread() 3928 result_queue.cancel_join_thread() 3929 return 3930 # If eof is not set, continue to put data to result_queue 3931 continue 3932 break 3933 del result, idx 3934 3935 3936class _GeneratorWorkerMt(threading.Thread): 3937 """ 3938 Worker process for multi-thread Generator. 3939 """ 3940 3941 def __init__(self, dataset, eof): 3942 self.idx_queue = queue.Queue(16) 3943 self.res_queue = queue.Queue(16) 3944 super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, False)) 3945 3946 def put(self, item): 3947 """ 3948 Put function for worker index queue. Never block. Raise queue.Full on failure. 3949 """ 3950 self.idx_queue.put_nowait(item) 3951 3952 def get(self): 3953 """ 3954 Get function for worker result queue. Block with timeout. 3955 """ 3956 return self.res_queue.get(timeout=30) 3957 3958 def queue_empty(self): 3959 if not self.idx_queue.empty(): 3960 logger.warning("idx_queue is not empty") 3961 return False 3962 if not self.res_queue.empty(): 3963 logger.warning("res_queue is not empty") 3964 return False 3965 return True 3966 3967 3968class _GeneratorWorkerMp(multiprocessing.Process): 3969 """ 3970 Worker process for multiprocess Generator. 3971 """ 3972 3973 def __init__(self, dataset, eof, max_rowsize, queue_size): 3974 self.idx_queue = multiprocessing.Queue(queue_size) 3975 if get_enable_shared_mem(): 3976 self.res_queue = _SharedQueue(queue_size, max_rowsize=max_rowsize) 3977 else: 3978 self.res_queue = multiprocessing.Queue(queue_size) 3979 self.idx_queue._joincancelled = True # pylint: disable=W0212 3980 self.res_queue._joincancelled = True # pylint: disable=W0212 3981 super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True)) 3982 3983 def put(self, item): 3984 """ 3985 Put function for worker index queue. Never block. Raise queue.Full on failure. 3986 """ 3987 self.idx_queue.put_nowait(item) 3988 3989 def get(self): 3990 """ 3991 Get function for worker result queue. Block with timeout. 3992 """ 3993 # Relax 10s to 30s, since it sometimes will cause "Generator worker process timeout" 3994 # when we run too many iterators with infinite epoch(num_epoch=-1) 3995 return self.res_queue.get(timeout=30) 3996 3997 def queue_empty(self): 3998 if not self.idx_queue.empty(): 3999 logger.warning("idx_queue is not empty.") 4000 return False 4001 if not self.res_queue.empty(): 4002 logger.warning("res_queue is not empty.") 4003 return False 4004 return True 4005 4006 4007class GeneratorDataset(MappableDataset): 4008 """ 4009 A source dataset that generates data from Python by invoking Python data source each epoch. 4010 4011 The column names and column types of generated dataset depend on Python data defined by users. 4012 4013 Args: 4014 source (Union[Callable, Iterable, Random Accessible]): 4015 A generator callable object, an iterable Python object or a random accessible Python object. 4016 Callable source is required to return a tuple of NumPy arrays as a row of the dataset on source().next(). 4017 Iterable source is required to return a tuple of NumPy arrays as a row of the dataset on 4018 iter(source).next(). 4019 Random accessible source is required to return a tuple of NumPy arrays as a row of the dataset on 4020 source[idx]. 4021 column_names (Union[str, list[str]], optional): List of column names of the dataset (default=None). Users are 4022 required to provide either column_names or schema. 4023 column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None). 4024 If provided, sanity check will be performed on generator output. 4025 schema (Union[Schema, str], optional): Path to the JSON schema file or schema object (default=None). Users are 4026 required to provide either column_names or schema. If both are provided, schema will be used. 4027 num_samples (int, optional): The number of samples to be included in the dataset 4028 (default=None, all images). 4029 num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1). 4030 shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required. 4031 (default=None, expected order behavior shown in the table). 4032 sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. Random accessible 4033 input is required (default=None, expected order behavior shown in the table). 4034 num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). 4035 Random accessible input is required. When this argument is specified, `num_samples` reflects the maximum 4036 sample number of per shard. 4037 shard_id (int, optional): The shard ID within num_shards (default=None). This argument must be specified only 4038 when num_shards is also specified. Random accessible input is required. 4039 python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This 4040 option could be beneficial if the Python operation is computational heavy (default=True). 4041 max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy 4042 data between processes. This is only used if python_multiprocessing is set to True (default 6 MB). 4043 4044 Raises: 4045 RuntimeError: If source raises an exception during execution. 4046 RuntimeError: If len of column_names does not match output len of source. 4047 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 4048 RuntimeError: If sampler and shuffle are specified at the same time. 4049 RuntimeError: If sampler and sharding are specified at the same time. 4050 RuntimeError: If num_shards is specified but shard_id is None. 4051 RuntimeError: If shard_id is specified but num_shards is None. 4052 ValueError: If shard_id is invalid (< 0 or >= num_shards). 4053 4054 Note: 4055 - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 4056 The table below shows what input arguments are allowed and their expected behavior. 4057 4058 .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` 4059 :widths: 25 25 50 4060 :header-rows: 1 4061 4062 * - Parameter `sampler` 4063 - Parameter `shuffle` 4064 - Expected Order Behavior 4065 * - None 4066 - None 4067 - random order 4068 * - None 4069 - True 4070 - random order 4071 * - None 4072 - False 4073 - sequential order 4074 * - Sampler object 4075 - None 4076 - order defined by sampler 4077 * - Sampler object 4078 - True 4079 - not allowed 4080 * - Sampler object 4081 - False 4082 - not allowed 4083 4084 Examples: 4085 >>> import numpy as np 4086 >>> 4087 >>> # 1) Multidimensional generator function as callable input. 4088 >>> def generator_multidimensional(): 4089 ... for i in range(64): 4090 ... yield (np.array([[i, i + 1], [i + 2, i + 3]]),) 4091 >>> 4092 >>> dataset = ds.GeneratorDataset(source=generator_multidimensional, column_names=["multi_dimensional_data"]) 4093 >>> 4094 >>> # 2) Multi-column generator function as callable input. 4095 >>> def generator_multi_column(): 4096 ... for i in range(64): 4097 ... yield np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]) 4098 >>> 4099 >>> dataset = ds.GeneratorDataset(source=generator_multi_column, column_names=["col1", "col2"]) 4100 >>> 4101 >>> # 3) Iterable dataset as iterable input. 4102 >>> class MyIterable: 4103 ... def __init__(self): 4104 ... self._index = 0 4105 ... self._data = np.random.sample((5, 2)) 4106 ... self._label = np.random.sample((5, 1)) 4107 ... 4108 ... def __next__(self): 4109 ... if self._index >= len(self._data): 4110 ... raise StopIteration 4111 ... else: 4112 ... item = (self._data[self._index], self._label[self._index]) 4113 ... self._index += 1 4114 ... return item 4115 ... 4116 ... def __iter__(self): 4117 ... self._index = 0 4118 ... return self 4119 ... 4120 ... def __len__(self): 4121 ... return len(self._data) 4122 >>> 4123 >>> dataset = ds.GeneratorDataset(source=MyIterable(), column_names=["data", "label"]) 4124 >>> 4125 >>> # 4) Random accessible dataset as random accessible input. 4126 >>> class MyAccessible: 4127 ... def __init__(self): 4128 ... self._data = np.random.sample((5, 2)) 4129 ... self._label = np.random.sample((5, 1)) 4130 ... 4131 ... def __getitem__(self, index): 4132 ... return self._data[index], self._label[index] 4133 ... 4134 ... def __len__(self): 4135 ... return len(self._data) 4136 >>> 4137 >>> dataset = ds.GeneratorDataset(source=MyAccessible(), column_names=["data", "label"]) 4138 >>> 4139 >>> # list, dict, tuple of Python is also random accessible 4140 >>> dataset = ds.GeneratorDataset(source=[(np.array(0),), (np.array(1),), (np.array(2),)], column_names=["col"]) 4141 """ 4142 4143 @check_generatordataset 4144 def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None, 4145 num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None, 4146 python_multiprocessing=True, max_rowsize=6): 4147 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 4148 shuffle=shuffle, num_shards=num_shards, shard_id=shard_id) 4149 self.source = source 4150 self.prepared_source = None # source to be sent to C++ 4151 4152 self.python_multiprocessing = python_multiprocessing 4153 4154 self.column_names = to_list(column_names) 4155 4156 if column_types is not None: 4157 self.column_types = mstypelist_to_detypelist(column_types) 4158 else: 4159 self.column_types = [] 4160 4161 self.schema = schema 4162 if schema is not None: 4163 self.schema = schema 4164 if not isinstance(schema, Schema): 4165 self.schema = Schema(schema) 4166 # Move get dataset_size by len from parse to here, because self.source will 4167 # lose attribution of '__len__' after deepcopy. 4168 self.source_len = -1 # unknown 4169 if hasattr(self.source, "__len__"): 4170 self.source_len = len(self.source) 4171 4172 self.max_rowsize = max_rowsize 4173 self.sample_fn = None 4174 4175 def __deepcopy__(self, memodict): 4176 if id(self) in memodict: 4177 return memodict[id(self)] 4178 new_op = self.__safe_deepcopy__(memodict, exclude=("source", "__transfer_dataset__")) 4179 4180 sample_fn = None 4181 if new_op.sampler is not None and hasattr(self.source, "__getitem__"): 4182 # The reason why there is a try catch here is because when the new op is being constructed with shared 4183 # memory enabled, there will be an exception thrown if there is not enough shared memory available 4184 if self.source_len == -1: 4185 raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!") 4186 try: 4187 if new_op.num_parallel_workers > 1: 4188 sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing, 4189 self.max_rowsize) 4190 new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn)) 4191 else: 4192 new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source)) 4193 new_op.sample_fn = sample_fn 4194 except RuntimeError as e: 4195 raise Exception(str(e)) 4196 else: 4197 try: 4198 new_op.sampler = None 4199 new_op.sample_fn = sample_fn 4200 new_op.source_len = min(new_op.source_len, 4201 new_op.num_samples) if new_op.num_samples != 0 else new_op.source_len 4202 iter(self.source) 4203 except TypeError: 4204 # Use generator function if input callable 4205 new_op.prepared_source = (lambda: _generator_fn(self.source, new_op.num_samples)) 4206 else: 4207 # Use iterator function if input is iterable 4208 # Random accessible input is also iterable 4209 new_op.prepared_source = (lambda: _iter_fn(self.source, new_op.num_samples)) 4210 4211 return new_op 4212 4213 def is_shuffled(self): 4214 return self.sampler.is_shuffled() 4215 4216 def is_sharded(self): 4217 return self.sampler.is_sharded() 4218 4219 def parse(self, children=None): 4220 if self.schema is None: 4221 return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len, 4222 self.sampler, self.num_parallel_workers) 4223 schema = self.schema 4224 if isinstance(schema, Schema): 4225 schema = self.schema.cpp_schema 4226 return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler, 4227 self.num_parallel_workers) 4228 4229 4230class TFRecordDataset(SourceDataset): 4231 """ 4232 A source dataset for reading and parsing datasets stored on disk in TFData format. 4233 4234 The columns of generated dataset depend on the source TFRecord files. 4235 4236 Args: 4237 dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a 4238 pattern of files. The list will be sorted in a lexicographical order. 4239 schema (Union[str, Schema], optional): Path to the JSON schema file or schema object (default=None). 4240 If the schema is not provided, the meta data from the TFData file is considered the schema. 4241 columns_list (list[str], optional): List of columns to be read (default=None, read all columns). 4242 num_samples (int, optional): The number of samples (rows) to be included in the dataset (default=None). 4243 If num_samples is None and numRows(parsed from schema) does not exist, read the full dataset; 4244 If num_samples is None and numRows(parsed from schema) is greater than 0, read numRows rows; 4245 If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows. 4246 num_parallel_workers (int, optional): Number of workers to read the data 4247 (default=None, number set in the config). 4248 shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch 4249 (default=Shuffle.GLOBAL). 4250 If shuffle is False, no shuffling will be performed; 4251 If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL 4252 Otherwise, there are two levels of shuffling: 4253 4254 - Shuffle.GLOBAL: Shuffle both the files and samples. 4255 4256 - Shuffle.FILES: Shuffle files only. 4257 4258 num_shards (int, optional): Number of shards that the dataset will be divided 4259 into (default=None). When this argument is specified, `num_samples` reflects 4260 the maximum sample number of per shard. 4261 shard_id (int, optional): The shard ID within num_shards (default=None). This 4262 argument can only be specified when num_shards is also specified. 4263 shard_equal_rows (bool, optional): Get equal rows for all shards(default=False). If shard_equal_rows 4264 is false, number of rows of each shard may be not equal, and may lead to a failure in distributed training. 4265 When the number of samples of per TFRecord file are not equal, it is suggested to set to true. 4266 This argument should only be specified when num_shards is also specified. 4267 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 4268 (default=None, which means no cache is used). 4269 4270 Raises: 4271 RuntimeError: If dataset_files are not valid or do not exist. 4272 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 4273 RuntimeError: If num_shards is specified but shard_id is None. 4274 RuntimeError: If shard_id is specified but num_shards is None. 4275 ValueError: If shard_id is invalid (< 0 or >= num_shards). 4276 4277 Examples: 4278 >>> from mindspore import dtype as mstype 4279 >>> 4280 >>> tfrecord_dataset_dir = ["/path/to/tfrecord_dataset_file"] # contains 1 or multiple TFRecord files 4281 >>> tfrecord_schema_file = "/path/to/tfrecord_schema_file" 4282 >>> 4283 >>> # 1) Get all rows from tfrecord_dataset_dir with no explicit schema. 4284 >>> # The meta-data in the first row will be used as a schema. 4285 >>> dataset = ds.TFRecordDataset(dataset_files=tfrecord_dataset_dir) 4286 >>> 4287 >>> # 2) Get all rows from tfrecord_dataset_dir with user-defined schema. 4288 >>> schema = ds.Schema() 4289 >>> schema.add_column(name='col_1d', de_type=mstype.int64, shape=[2]) 4290 >>> dataset = ds.TFRecordDataset(dataset_files=tfrecord_dataset_dir, schema=schema) 4291 >>> 4292 >>> # 3) Get all rows from tfrecord_dataset_dir with schema file. 4293 >>> dataset = ds.TFRecordDataset(dataset_files=tfrecord_dataset_dir, schema=tfrecord_schema_file) 4294 """ 4295 4296 @check_tfrecorddataset 4297 def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, 4298 shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None): 4299 super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, 4300 num_shards=num_shards, shard_id=shard_id, cache=cache) 4301 self.dataset_files = self._find_files(dataset_files) 4302 self.dataset_files.sort() 4303 4304 self.schema = schema 4305 self.columns_list = replace_none(columns_list, []) 4306 self.shard_equal_rows = replace_none(shard_equal_rows, False) 4307 4308 if self.schema is not None and (self.num_samples is None or self.num_samples == 0): 4309 self.num_samples = Schema.get_num_rows(self.schema) 4310 4311 def parse(self, children=None): 4312 schema = self.schema.cpp_schema if isinstance(self.schema, Schema) else self.schema 4313 return cde.TFRecordNode(self.dataset_files, schema, self.columns_list, self.num_samples, self.shuffle_flag, 4314 self.num_shards, self.shard_id, self.shard_equal_rows) 4315 4316 4317class ManifestDataset(MappableDataset): 4318 """ 4319 A source dataset for reading images from a Manifest file. 4320 4321 The generated dataset has two columns: :py:obj:`[image, label]`. 4322 The tensor of column :py:obj:`image` is of the uint8 type. 4323 The tensor of column :py:obj:`label` is of a scalar of uint64 type. 4324 4325 Args: 4326 dataset_file (str): File to be read. 4327 usage (str, optional): Acceptable usages include `train`, `eval` and `inference` (default=`train`). 4328 num_samples (int, optional): The number of images to be included in the dataset. 4329 (default=None, will include all images). 4330 num_parallel_workers (int, optional): Number of workers to read the data 4331 (default=None, will use value set in the config). 4332 shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected 4333 order behavior shown in the table). 4334 sampler (Sampler, optional): Object used to choose samples from the 4335 dataset (default=None, expected order behavior shown in the table). 4336 class_indexing (dict, optional): A str-to-int mapping from label name to index 4337 (default=None, the folder names will be sorted alphabetically and each 4338 class will be given a unique index starting from 0). 4339 decode (bool, optional): decode the images after reading (default=False). 4340 num_shards (int, optional): Number of shards that the dataset will be divided 4341 into (default=None). When this argument is specified, `num_samples` reflects 4342 the max number of samples per shard. 4343 shard_id (int, optional): The shard ID within `num_shards` (default=None). This 4344 argument can only be specified when `num_shards` is also specified. 4345 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 4346 (default=None, which means no cache is used). 4347 4348 Raises: 4349 RuntimeError: If dataset_files are not valid or do not exist. 4350 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 4351 RuntimeError: If sampler and shuffle are specified at the same time. 4352 RuntimeError: If sampler and sharding are specified at the same time. 4353 RuntimeError: If num_shards is specified but shard_id is None. 4354 RuntimeError: If shard_id is specified but num_shards is None. 4355 RuntimeError: If class_indexing is not a dictionary. 4356 ValueError: If shard_id is invalid (< 0 or >= num_shards). 4357 4358 Note: 4359 - The shape of the image column is [image_size] if decode flag is False, or [H,W,C] otherwise. 4360 - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 4361 The table below shows what input arguments are allowed and their expected behavior. 4362 4363 .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` 4364 :widths: 25 25 50 4365 :header-rows: 1 4366 4367 * - Parameter `sampler` 4368 - Parameter `shuffle` 4369 - Expected Order Behavior 4370 * - None 4371 - None 4372 - random order 4373 * - None 4374 - True 4375 - random order 4376 * - None 4377 - False 4378 - sequential order 4379 * - Sampler object 4380 - None 4381 - order defined by sampler 4382 * - Sampler object 4383 - True 4384 - not allowed 4385 * - Sampler object 4386 - False 4387 - not allowed 4388 4389 Examples: 4390 >>> manifest_dataset_dir = "/path/to/manifest_dataset_file" 4391 >>> 4392 >>> # 1) Read all samples specified in manifest_dataset_dir dataset with 8 threads for training 4393 >>> dataset = ds.ManifestDataset(dataset_file=manifest_dataset_dir, usage="train", num_parallel_workers=8) 4394 >>> 4395 >>> # 2) Read samples (specified in manifest_file.manifest) for shard 0 in a 2-way distributed training setup 4396 >>> dataset = ds.ManifestDataset(dataset_file=manifest_dataset_dir, num_shards=2, shard_id=0) 4397 """ 4398 4399 @check_manifestdataset 4400 def __init__(self, dataset_file, usage="train", num_samples=None, num_parallel_workers=None, shuffle=None, 4401 sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None, cache=None): 4402 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 4403 shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) 4404 4405 self.dataset_file = dataset_file 4406 self.decode = replace_none(decode, False) 4407 self.usage = replace_none(usage, "train") 4408 self.class_indexing = replace_none(class_indexing, {}) 4409 4410 def parse(self, children=None): 4411 return cde.ManifestNode(self.dataset_file, self.usage, self.sampler, self.class_indexing, self.decode) 4412 4413 def get_class_indexing(self): 4414 """ 4415 Get the class index. 4416 4417 Returns: 4418 dict, a str-to-int mapping from label name to index. 4419 4420 Examples: 4421 >>> manifest_dataset_dir = "/path/to/manifest_dataset_file" 4422 >>> 4423 >>> dataset = ds.ManifestDataset(dataset_file=manifest_dataset_dir) 4424 >>> class_indexing = dataset.get_class_indexing() 4425 """ 4426 if self.class_indexing is None or not self.class_indexing: 4427 if self._class_indexing is None: 4428 runtime_getter = self._init_tree_getters() 4429 self._class_indexing = runtime_getter[0].GetClassIndexing() 4430 self.class_indexing = {} 4431 for pair in self._class_indexing: 4432 self.class_indexing[pair[0]] = pair[1][0] 4433 return self.class_indexing 4434 4435 4436class Cifar10Dataset(MappableDataset): 4437 """ 4438 A source dataset for reading and parsing Cifar10 dataset. 4439 This api only supports parsing Cifar10 file in binary version now. 4440 4441 The generated dataset has two columns :py:obj:`[image, label]`. 4442 The tensor of column :py:obj:`image` is of the uint8 type. 4443 The tensor of column :py:obj:`label` is a scalar of the uint32 type. 4444 4445 Args: 4446 dataset_dir (str): Path to the root directory that contains the dataset. 4447 usage (str, optional): Usage of this dataset, can be `train`, `test` or `all` . `train` will read from 50,000 4448 train samples, `test` will read from 10,000 test samples, `all` will read from all 60,000 samples 4449 (default=None, all samples). 4450 num_samples (int, optional): The number of images to be included in the dataset 4451 (default=None, all images). 4452 num_parallel_workers (int, optional): Number of workers to read the data 4453 (default=None, number set in the config). 4454 shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected 4455 order behavior shown in the table). 4456 sampler (Sampler, optional): Object used to choose samples from the 4457 dataset (default=None, expected order behavior shown in the table). 4458 num_shards (int, optional): Number of shards that the dataset will be divided 4459 into (default=None). When this argument is specified, `num_samples` reflects 4460 the maximum sample number of per shard. 4461 shard_id (int, optional): The shard ID within num_shards (default=None). This 4462 argument can only be specified when num_shards is also specified. 4463 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 4464 (default=None, which means no cache is used). 4465 4466 Raises: 4467 RuntimeError: If dataset_dir does not contain data files. 4468 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 4469 RuntimeError: If sampler and shuffle are specified at the same time. 4470 RuntimeError: If sampler and sharding are specified at the same time. 4471 RuntimeError: If num_shards is specified but shard_id is None. 4472 RuntimeError: If shard_id is specified but num_shards is None. 4473 ValueError: If shard_id is invalid (< 0 or >= num_shards). 4474 4475 Note: 4476 - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 4477 The table below shows what input arguments are allowed and their expected behavior. 4478 4479 .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` 4480 :widths: 25 25 50 4481 :header-rows: 1 4482 4483 * - Parameter `sampler` 4484 - Parameter `shuffle` 4485 - Expected Order Behavior 4486 * - None 4487 - None 4488 - random order 4489 * - None 4490 - True 4491 - random order 4492 * - None 4493 - False 4494 - sequential order 4495 * - Sampler object 4496 - None 4497 - order defined by sampler 4498 * - Sampler object 4499 - True 4500 - not allowed 4501 * - Sampler object 4502 - False 4503 - not allowed 4504 4505 Examples: 4506 >>> cifar10_dataset_dir = "/path/to/cifar10_dataset_directory" 4507 >>> 4508 >>> # 1) Get all samples from CIFAR10 dataset in sequence 4509 >>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, shuffle=False) 4510 >>> 4511 >>> # 2) Randomly select 350 samples from CIFAR10 dataset 4512 >>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_samples=350, shuffle=True) 4513 >>> 4514 >>> # 3) Get samples from CIFAR10 dataset for shard 0 in a 2-way distributed training 4515 >>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_shards=2, shard_id=0) 4516 >>> 4517 >>> # In CIFAR10 dataset, each dictionary has keys "image" and "label" 4518 4519 About CIFAR-10 dataset: 4520 4521 The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, 4522 with 6000 images per class. There are 50000 training images and 10000 test images. 4523 The 10 different classes represent airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. 4524 4525 Here is the original CIFAR-10 dataset structure. 4526 You can unzip the dataset files into the following directory structure and read by MindSpore's API. 4527 4528 .. code-block:: 4529 4530 . 4531 └── cifar-10-batches-bin 4532 ├── data_batch_1.bin 4533 ├── data_batch_2.bin 4534 ├── data_batch_3.bin 4535 ├── data_batch_4.bin 4536 ├── data_batch_5.bin 4537 ├── test_batch.bin 4538 ├── readme.html 4539 └── batches.meta.txt 4540 4541 Citation: 4542 4543 .. code-block:: 4544 4545 @techreport{Krizhevsky09, 4546 author = {Alex Krizhevsky}, 4547 title = {Learning multiple layers of features from tiny images}, 4548 institution = {}, 4549 year = {2009}, 4550 howpublished = {http://www.cs.toronto.edu/~kriz/cifar.html} 4551 } 4552 """ 4553 4554 @check_mnist_cifar_dataset 4555 def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, 4556 sampler=None, num_shards=None, shard_id=None, cache=None): 4557 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 4558 shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) 4559 4560 self.dataset_dir = dataset_dir 4561 self.usage = replace_none(usage, "all") 4562 4563 def parse(self, children=None): 4564 return cde.Cifar10Node(self.dataset_dir, self.usage, self.sampler) 4565 4566 4567class Cifar100Dataset(MappableDataset): 4568 """ 4569 A source dataset for reading and parsing Cifar100 dataset. 4570 4571 The generated dataset has three columns :py:obj:`[image, coarse_label, fine_label]`. 4572 The tensor of column :py:obj:`image` is of the uint8 type. 4573 The tensor of column :py:obj:`coarse_label` and :py:obj:`fine_labels` are each a scalar of uint32 type. 4574 4575 Args: 4576 dataset_dir (str): Path to the root directory that contains the dataset. 4577 usage (str, optional): Usage of this dataset, can be `train`, `test` or `all` . `train` will read from 50,000 4578 train samples, `test` will read from 10,000 test samples, `all` will read from all 60,000 samples 4579 (default=None, all samples). 4580 num_samples (int, optional): The number of images to be included in the dataset 4581 (default=None, all images). 4582 num_parallel_workers (int, optional): Number of workers to read the data 4583 (default=None, number set in the config). 4584 shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected 4585 order behavior shown in the table). 4586 sampler (Sampler, optional): Object used to choose samples from the 4587 dataset (default=None, expected order behavior shown in the table). 4588 num_shards (int, optional): Number of shards that the dataset will be divided 4589 into (default=None). When this argument is specified, 'num_samples' reflects 4590 the maximum sample number of per shard. 4591 shard_id (int, optional): The shard ID within num_shards (default=None). This 4592 argument can only be specified when num_shards is also specified. 4593 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 4594 (default=None, which means no cache is used). 4595 4596 Raises: 4597 RuntimeError: If dataset_dir does not contain data files. 4598 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 4599 RuntimeError: If sampler and shuffle are specified at the same time. 4600 RuntimeError: If sampler and sharding are specified at the same time. 4601 RuntimeError: If num_shards is specified but shard_id is None. 4602 RuntimeError: If shard_id is specified but num_shards is None. 4603 ValueError: If shard_id is invalid (< 0 or >= num_shards). 4604 4605 Note: 4606 - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 4607 The table below shows what input arguments are allowed and their expected behavior. 4608 4609 .. list-table:: Expected Order Behavior of Using `sampler` and shuffle 4610 :widths: 25 25 50 4611 :header-rows: 1 4612 4613 * - Parameter `sampler` 4614 - Parameter `shuffle` 4615 - Expected Order Behavior 4616 * - None 4617 - None 4618 - random order 4619 * - None 4620 - True 4621 - random order 4622 * - None 4623 - False 4624 - sequential order 4625 * - Sampler object 4626 - None 4627 - order defined by sampler 4628 * - Sampler object 4629 - True 4630 - not allowed 4631 * - Sampler object 4632 - False 4633 - not allowed 4634 4635 Examples: 4636 >>> cifar100_dataset_dir = "/path/to/cifar100_dataset_directory" 4637 >>> 4638 >>> # 1) Get all samples from CIFAR100 dataset in sequence 4639 >>> dataset = ds.Cifar100Dataset(dataset_dir=cifar100_dataset_dir, shuffle=False) 4640 >>> 4641 >>> # 2) Randomly select 350 samples from CIFAR100 dataset 4642 >>> dataset = ds.Cifar100Dataset(dataset_dir=cifar100_dataset_dir, num_samples=350, shuffle=True) 4643 >>> 4644 >>> # In CIFAR100 dataset, each dictionary has 3 keys: "image", "fine_label" and "coarse_label" 4645 4646 About CIFAR-100 dataset: 4647 4648 This dataset is just like the CIFAR-10, except it has 100 classes containing 600 images 4649 each. There are 500 training images and 100 testing images per class. The 100 classes in 4650 the CIFAR-100 are grouped into 20 superclasses. Each image comes with a "fine" label (the 4651 class to which it belongs) and a "coarse" label (the superclass to which it belongs). 4652 4653 Here is the original CIFAR-100 dataset structure. 4654 You can unzip the dataset files into the following directory structure and read by MindSpore's API. 4655 4656 .. code-block:: 4657 4658 . 4659 └── cifar-100-binary 4660 ├── train.bin 4661 ├── test.bin 4662 ├── fine_label_names.txt 4663 └── coarse_label_names.txt 4664 4665 Citation: 4666 4667 .. code-block:: 4668 4669 @techreport{Krizhevsky09, 4670 author = {Alex Krizhevsky}, 4671 title = {Learning multiple layers of features from tiny images}, 4672 institution = {}, 4673 year = {2009}, 4674 howpublished = {http://www.cs.toronto.edu/~kriz/cifar.html} 4675 } 4676 """ 4677 4678 @check_mnist_cifar_dataset 4679 def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, 4680 sampler=None, num_shards=None, shard_id=None, cache=None): 4681 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 4682 shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) 4683 4684 self.dataset_dir = dataset_dir 4685 self.usage = replace_none(usage, "all") 4686 4687 def parse(self, children=None): 4688 return cde.Cifar100Node(self.dataset_dir, self.usage, self.sampler) 4689 4690 4691class RandomDataset(SourceDataset): 4692 """ 4693 A source dataset that generates random data. 4694 4695 Args: 4696 total_rows (int, optional): Number of samples for the dataset to generate 4697 (default=None, number of samples is random). 4698 schema (Union[str, Schema], optional): Path to the JSON schema file or schema object (default=None). 4699 If the schema is not provided, the random dataset generates a random schema. 4700 columns_list (list[str], optional): List of columns to be read (default=None, read all columns) 4701 num_samples (int, optional): The number of samples to be included in the dataset 4702 (default=None, all samples). 4703 num_parallel_workers (int, optional): Number of workers to read the data 4704 (default=None, number set in the config). 4705 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 4706 (default=None, which means no cache is used). 4707 shuffle (bool, optional): Whether or not to perform shuffle on the dataset 4708 (default=None, expected order behavior shown in the table). 4709 num_shards (int, optional): Number of shards that the dataset will be divided 4710 into (default=None). When this argument is specified, 'num_samples' reflects 4711 the maximum sample number of per shard. 4712 shard_id (int, optional): The shard ID within num_shards (default=None). This 4713 argument can only be specified when num_shards is also specified. 4714 """ 4715 4716 @check_random_dataset 4717 def __init__(self, total_rows=None, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, 4718 cache=None, shuffle=None, num_shards=None, shard_id=None): 4719 super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, 4720 num_shards=num_shards, shard_id=shard_id, cache=cache) 4721 self.total_rows = total_rows 4722 if schema is not None: 4723 self.total_rows = replace_none(total_rows, Schema.get_num_rows(schema)) 4724 self.schema = schema 4725 self.columns_list = replace_none(columns_list, []) 4726 4727 def parse(self, children=None): 4728 schema = self.schema.cpp_schema if isinstance(self.schema, Schema) else self.schema 4729 return cde.RandomNode(self.total_rows, schema, self.columns_list) 4730 4731 4732class Schema: 4733 """ 4734 Class to represent a schema of a dataset. 4735 4736 Args: 4737 schema_file(str): Path of the schema file (default=None). 4738 4739 Returns: 4740 Schema object, schema info about dataset. 4741 4742 Raises: 4743 RuntimeError: If schema file failed to load. 4744 4745 Examples: 4746 >>> from mindspore import dtype as mstype 4747 >>> 4748 >>> # Create schema; specify column name, mindspore.dtype and shape of the column 4749 >>> schema = ds.Schema() 4750 >>> schema.add_column(name='col1', de_type=mstype.int64, shape=[2]) 4751 """ 4752 4753 @check_schema 4754 def __init__(self, schema_file=None): 4755 self.schema_file = replace_none(schema_file, "") 4756 self.cpp_schema = cde.SchemaObj(self.schema_file) 4757 4758 @check_add_column 4759 def add_column(self, name, de_type, shape=None): 4760 """ 4761 Add new column to the schema. 4762 4763 Args: 4764 name (str): The new name of the column. 4765 de_type (str): Data type of the column. 4766 shape (list[int], optional): Shape of the column 4767 (default=None, [-1] which is an unknown shape of rank 1). 4768 4769 Raises: 4770 ValueError: If column type is unknown. 4771 """ 4772 if isinstance(de_type, typing.Type): 4773 de_type = mstype_to_detype(de_type) 4774 col_type = str(de_type) 4775 else: 4776 col_type = str(cde.DataType(de_type)) 4777 if shape is None: 4778 self.cpp_schema.add_column(name, col_type) 4779 else: 4780 self.cpp_schema.add_column(name, col_type, shape) 4781 4782 def parse_columns(self, columns): 4783 """ 4784 Parse the columns and add it to self. 4785 4786 Args: 4787 columns (Union[dict, list[dict], tuple[dict]]): Dataset attribute information, decoded from schema file. 4788 4789 - list[dict], 'name' and 'type' must be in keys, 'shape' optional. 4790 4791 - dict, columns.keys() as name, columns.values() is dict, and 'type' inside, 'shape' optional. 4792 4793 Raises: 4794 RuntimeError: If failed to parse columns. 4795 RuntimeError: If column's name field is missing. 4796 RuntimeError: If column's type field is missing. 4797 4798 Examples: 4799 >>> schema = Schema() 4800 >>> columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]}, 4801 >>> {'name': 'label', 'type': 'int8', 'shape': [1]}] 4802 >>> schema.parse_columns(columns1) 4803 >>> columns2 = {'image': {'shape': [3, 3], 'type': 'int8'}, 'label': {'shape': [1], 'type': 'int8'}} 4804 >>> schema.parse_columns(columns2) 4805 """ 4806 self.cpp_schema.parse_columns(json.dumps(columns, indent=2)) 4807 4808 def to_json(self): 4809 """ 4810 Get a JSON string of the schema. 4811 4812 Returns: 4813 str, JSON string of the schema. 4814 """ 4815 return self.cpp_schema.to_json() 4816 4817 def from_json(self, json_obj): 4818 """ 4819 Get schema file from JSON object. 4820 4821 Args: 4822 json_obj(dictionary): Object of JSON parsed. 4823 4824 Raises: 4825 RuntimeError: if there is unknown item in the object. 4826 RuntimeError: if dataset type is missing in the object. 4827 RuntimeError: if columns are missing in the object. 4828 """ 4829 self.cpp_schema.from_string(json.dumps(json_obj, indent=2)) 4830 4831 def __str__(self): 4832 return self.to_json() 4833 4834 @staticmethod 4835 def get_num_rows(schema): 4836 schema_obj = schema 4837 if not isinstance(schema_obj, Schema): 4838 schema_obj = Schema(schema_obj) 4839 return schema_obj.cpp_schema.get_num_rows() 4840 4841 4842class USPSDataset(SourceDataset): 4843 """ 4844 A source dataset for reading and parsing the USPS dataset. 4845 4846 The generated dataset has two columns: :py:obj:`[image, label]`. 4847 The tensor of column :py:obj:`image` is of the uint8 type. 4848 The tensor of column :py:obj:`label` is of a scalar of uint32 type. 4849 4850 Args: 4851 dataset_dir (str): Path to the root directory that contains the dataset. 4852 usage (str, optional): Usage of this dataset, can be "train", "test" or "all". "train" will read from 7,291 4853 train samples, "test" will read from 2,007 test samples, "all" will read from all 9,298 samples. 4854 (default=None, will read all samples) 4855 num_samples (int, optional): The number of images to be included in the dataset 4856 (default=None, will read all images). 4857 num_parallel_workers (int, optional): Number of workers to read the data 4858 (default=None, will use value set in the config). 4859 shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch 4860 (default=Shuffle.GLOBAL). 4861 If shuffle is False, no shuffling will be performed; 4862 If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL 4863 Otherwise, there are two levels of shuffling: 4864 4865 - Shuffle.GLOBAL: Shuffle both the files and samples. 4866 4867 - Shuffle.FILES: Shuffle files only. 4868 4869 num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). 4870 When this argument is specified, `num_samples` reflects the max sample number of per shard. 4871 shard_id (int, optional): The shard ID within `num_shards` (default=None). This 4872 argument can only be specified when `num_shards` is also specified. 4873 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 4874 (default=None, which means no cache is used). 4875 4876 Raises: 4877 RuntimeError: If dataset_dir is not valid or does not exist or does not contain data files. 4878 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 4879 RuntimeError: If sampler and shuffle are specified at the same time. 4880 RuntimeError: If sampler and sharding are specified at the same time. 4881 RuntimeError: If num_shards is specified but shard_id is None. 4882 RuntimeError: If shard_id is specified but num_shards is None. 4883 ValueError: If usage is invalid. 4884 ValueError: If shard_id is invalid (< 0 or >= num_shards). 4885 4886 Examples: 4887 >>> usps_dataset_dir = "/path/to/usps_dataset_directory" 4888 >>> 4889 >>> # Read 3 samples from USPS dataset 4890 >>> dataset = ds.USPSDataset(dataset_dir=usps_dataset_dir, num_samples=3) 4891 >>> 4892 >>> # Note: In USPS dataset, each dictionary has keys "image" and "label" 4893 4894 About USPS dataset: 4895 4896 USPS is a digit dataset automatically scanned from envelopes by the U.S. Postal Service 4897 containing a total of 9,298 16×16 pixel grayscale samples. 4898 The images are centered, normalized and show a broad range of font styles. 4899 4900 Here is the original USPS dataset structure. 4901 You can download and unzip the dataset files into this directory structure and read by MindSpore's API. 4902 4903 .. code-block:: 4904 . 4905 └── usps_dataset_dir 4906 ├── usps 4907 ├── usps.t 4908 4909 Citation: 4910 4911 .. code-block:: 4912 4913 @article{hull1994database, 4914 title={A database for handwritten text recognition research}, 4915 author={Hull, Jonathan J.}, 4916 journal={IEEE Transactions on pattern analysis and machine intelligence}, 4917 volume={16}, 4918 number={5}, 4919 pages={550--554}, 4920 year={1994}, 4921 publisher={IEEE} 4922 } 4923 """ 4924 4925 @check_usps_dataset 4926 def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, 4927 num_shards=None, shard_id=None, cache=None): 4928 super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, 4929 num_shards=num_shards, shard_id=shard_id, cache=cache) 4930 4931 self.dataset_dir = dataset_dir 4932 self.usage = replace_none(usage, "all") 4933 4934 def parse(self, children=None): 4935 return cde.USPSNode(self.dataset_dir, self.usage, self.num_samples, self.shuffle_flag, self.num_shards, 4936 self.shard_id) 4937 4938 4939class VOCDataset(MappableDataset): 4940 """ 4941 A source dataset for reading and parsing VOC dataset. 4942 4943 The generated dataset with different task setting has different output columns: 4944 4945 - task = :py:obj:`Detection`, output columns: :py:obj:`[image, dtype=uint8]`, :py:obj:`[bbox, dtype=float32]`, \ 4946 :py:obj:`[label, dtype=uint32]`, :py:obj:`[difficult, dtype=uint32]`, :py:obj:`[truncate, dtype=uint32]`. 4947 - task = :py:obj:`Segmentation`, output columns: :py:obj:`[image, dtype=uint8]`, :py:obj:`[target,dtype=uint8]`. 4948 4949 Args: 4950 dataset_dir (str): Path to the root directory that contains the dataset. 4951 task (str, optional): Set the task type of reading voc data, now only support `Segmentation` or `Detection` 4952 (default=`Segmentation`). 4953 usage (str, optional): Set the task type of ImageSets(default=`train`). If task is `Segmentation`, image and 4954 annotation list will be loaded in ./ImageSets/Segmentation/usage + ".txt"; If task is `Detection`, image and 4955 annotation list will be loaded in ./ImageSets/Main/usage + ".txt"; if task and usage are not set, image and 4956 annotation list will be loaded in ./ImageSets/Segmentation/train.txt as default. 4957 class_indexing (dict, optional): A str-to-int mapping from label name to index, only valid in 4958 `Detection` task (default=None, the folder names will be sorted alphabetically and each 4959 class will be given a unique index starting from 0). 4960 num_samples (int, optional): The number of images to be included in the dataset 4961 (default=None, all images). 4962 num_parallel_workers (int, optional): Number of workers to read the data 4963 (default=None, number set in the config). 4964 shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected 4965 order behavior shown in the table). 4966 decode (bool, optional): Decode the images after reading (default=False). 4967 sampler (Sampler, optional): Object used to choose samples from the dataset 4968 (default=None, expected order behavior shown in the table). 4969 num_shards (int, optional): Number of shards that the dataset will be divided 4970 into (default=None). When this argument is specified, `num_samples` reflects 4971 the maximum sample number of per shard. 4972 shard_id (int, optional): The shard ID within num_shards (default=None). This 4973 argument can only be specified when num_shards is also specified. 4974 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 4975 (default=None, which means no cache is used). 4976 extra_metadata(bool, optional): Flag to add extra meta-data to row. If True, an additional column named 4977 :py:obj:`[_meta-filename, dtype=string]` will be output at the end (default=False). 4978 4979 Raises: 4980 RuntimeError: If dataset_dir does not contain data files. 4981 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 4982 RuntimeError: If xml of Annotations is an invalid format. 4983 RuntimeError: If xml of Annotations loss attribution of `object`. 4984 RuntimeError: If xml of Annotations loss attribution of `bndbox`. 4985 RuntimeError: If sampler and shuffle are specified at the same time. 4986 RuntimeError: If sampler and sharding are specified at the same time. 4987 RuntimeError: If num_shards is specified but shard_id is None. 4988 RuntimeError: If shard_id is specified but num_shards is None. 4989 ValueError: If task is not equal 'Segmentation' or 'Detection'. 4990 ValueError: If task equal 'Segmentation' but class_indexing is not None. 4991 ValueError: If txt related to mode is not exist. 4992 ValueError: If shard_id is invalid (< 0 or >= num_shards). 4993 4994 Note: 4995 - Column '[_meta-filename, dtype=string]' won't be output unless an explicit rename dataset op 4996 is added to remove the prefix('_meta-'). 4997 - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 4998 The table below shows what input arguments are allowed and their expected behavior. 4999 5000 .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` 5001 :widths: 25 25 50 5002 :header-rows: 1 5003 5004 * - Parameter `sampler` 5005 - Parameter `shuffle` 5006 - Expected Order Behavior 5007 * - None 5008 - None 5009 - random order 5010 * - None 5011 - True 5012 - random order 5013 * - None 5014 - False 5015 - sequential order 5016 * - Sampler object 5017 - None 5018 - order defined by sampler 5019 * - Sampler object 5020 - True 5021 - not allowed 5022 * - Sampler object 5023 - False 5024 - not allowed 5025 5026 Examples: 5027 >>> voc_dataset_dir = "/path/to/voc_dataset_directory" 5028 >>> 5029 >>> # 1) Read VOC data for segmentation training 5030 >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir, task="Segmentation", usage="train") 5031 >>> 5032 >>> # 2) Read VOC data for detection training 5033 >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir, task="Detection", usage="train") 5034 >>> 5035 >>> # 3) Read all VOC dataset samples in voc_dataset_dir with 8 threads in random order 5036 >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir, task="Detection", usage="train", 5037 ... num_parallel_workers=8) 5038 >>> 5039 >>> # 4) Read then decode all VOC dataset samples in voc_dataset_dir in sequence 5040 >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir, task="Detection", usage="train", 5041 ... decode=True, shuffle=False) 5042 >>> 5043 >>> # In VOC dataset, if task='Segmentation', each dictionary has keys "image" and "target" 5044 >>> # In VOC dataset, if task='Detection', each dictionary has keys "image" and "annotation" 5045 5046 About VOC dataset. 5047 5048 The PASCAL Visual Object Classes (VOC) challenge is a benchmark in visual 5049 object category recognition and detection, providing the vision and machine 5050 learning communities with a standard dataset of images and annotation, and 5051 standard evaluation procedures. 5052 5053 You can unzip the original VOC-2012 dataset files into this directory structure and read by MindSpore's API. 5054 5055 .. code-block:: 5056 5057 . 5058 └── voc2012_dataset_dir 5059 ├── Annotations 5060 │ ├── 2007_000027.xml 5061 │ ├── 2007_000032.xml 5062 │ ├── ... 5063 ├── ImageSets 5064 │ ├── Action 5065 │ ├── Layout 5066 │ ├── Main 5067 │ └── Segmentation 5068 ├── JPEGImages 5069 │ ├── 2007_000027.jpg 5070 │ ├── 2007_000032.jpg 5071 │ ├── ... 5072 ├── SegmentationClass 5073 │ ├── 2007_000032.png 5074 │ ├── 2007_000033.png 5075 │ ├── ... 5076 └── SegmentationObject 5077 ├── 2007_000032.png 5078 ├── 2007_000033.png 5079 ├── ... 5080 5081 Citation: 5082 5083 .. code-block:: 5084 5085 @article{Everingham10, 5086 author = {Everingham, M. and Van~Gool, L. and Williams, C. K. I. and Winn, J. and Zisserman, A.}, 5087 title = {The Pascal Visual Object Classes (VOC) Challenge}, 5088 journal = {International Journal of Computer Vision}, 5089 volume = {88}, 5090 year = {2012}, 5091 number = {2}, 5092 month = {jun}, 5093 pages = {303--338}, 5094 biburl = {http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham10.html#bibtex}, 5095 howpublished = {http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html} 5096 } 5097 """ 5098 5099 @check_vocdataset 5100 def __init__(self, dataset_dir, task="Segmentation", usage="train", class_indexing=None, num_samples=None, 5101 num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, 5102 cache=None, extra_metadata=False): 5103 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 5104 shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) 5105 self.dataset_dir = dataset_dir 5106 self.task = replace_none(task, "Segmentation") 5107 self.usage = replace_none(usage, "train") 5108 self.class_indexing = replace_none(class_indexing, {}) 5109 self.decode = replace_none(decode, False) 5110 self.extra_metadata = extra_metadata 5111 5112 def parse(self, children=None): 5113 return cde.VOCNode(self.dataset_dir, self.task, self.usage, self.class_indexing, self.decode, self.sampler, 5114 self.extra_metadata) 5115 5116 def get_class_indexing(self): 5117 """ 5118 Get the class index. 5119 5120 Returns: 5121 dict, a str-to-int mapping from label name to index. 5122 5123 Examples: 5124 >>> voc_dataset_dir = "/path/to/voc_dataset_directory" 5125 >>> 5126 >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir) 5127 >>> class_indexing = dataset.get_class_indexing() 5128 """ 5129 if self.task != "Detection": 5130 raise NotImplementedError("Only 'Detection' support get_class_indexing.") 5131 if self.class_indexing is None or not self.class_indexing: 5132 if self._class_indexing is None: 5133 runtime_getter = self._init_tree_getters() 5134 self._class_indexing = runtime_getter[0].GetClassIndexing() 5135 self.class_indexing = {} 5136 for pair in self._class_indexing: 5137 self.class_indexing[pair[0]] = pair[1][0] 5138 return self.class_indexing 5139 5140 5141class CocoDataset(MappableDataset): 5142 """ 5143 A source dataset for reading and parsing COCO dataset. 5144 5145 CocoDataset supports four kinds of tasks, which are Object Detection, Keypoint Detection, Stuff Segmentation and 5146 Panoptic Segmentation of 2017 Train/Val/Test dataset. 5147 5148 The generated dataset with different task setting has different output columns: 5149 5150 - task = :py:obj:`Detection`, output columns: :py:obj:`[image, dtype=uint8]`, :py:obj:`[bbox, dtype=float32]`, \ 5151 :py:obj:`[category_id, dtype=uint32]`, :py:obj:`[iscrowd, dtype=uint32]`. 5152 - task = :py:obj:`Stuff`, output columns: :py:obj:`[image, dtype=uint8]`, :py:obj:`[segmentation,dtype=float32]`, \ 5153 :py:obj:`[iscrowd,dtype=uint32]`. 5154 - task = :py:obj:`Keypoint`, output columns: :py:obj:`[image, dtype=uint8]`, \ 5155 :py:obj:`[keypoints, dtype=float32]`, :py:obj:`[num_keypoints, dtype=uint32]`. 5156 - task = :py:obj:`Panoptic`, output columns: :py:obj:`[image, dtype=uint8]`, :py:obj:`[bbox, dtype=float32]`, \ 5157 :py:obj:`[category_id, dtype=uint32]`, :py:obj:`[iscrowd, dtype=uint32]`, :py:obj:`[area, dtype=uint32]`. 5158 5159 Args: 5160 dataset_dir (str): Path to the root directory that contains the dataset. 5161 annotation_file (str): Path to the annotation JSON file. 5162 task (str, optional): Set the task type for reading COCO data. Supported task types: 5163 `Detection`, `Stuff`, `Panoptic` and `Keypoint` (default=`Detection`). 5164 num_samples (int, optional): The number of images to be included in the dataset 5165 (default=None, all images). 5166 num_parallel_workers (int, optional): Number of workers to read the data 5167 (default=None, number set in the configuration file). 5168 shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected 5169 order behavior shown in the table). 5170 decode (bool, optional): Decode the images after reading (default=False). 5171 sampler (Sampler, optional): Object used to choose samples from the dataset 5172 (default=None, expected order behavior shown in the table). 5173 num_shards (int, optional): Number of shards that the dataset will be divided 5174 into (default=None). When this argument is specified, `num_samples` reflects 5175 the maximum sample number of per shard. 5176 shard_id (int, optional): The shard ID within num_shards (default=None). This 5177 argument can only be specified when num_shards is also specified. 5178 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 5179 (default=None, which means no cache is used). 5180 extra_metadata(bool, optional): Flag to add extra meta-data to row. If True, an additional column will be 5181 output at the end :py:obj:`[_meta-filename, dtype=string]` (default=False). 5182 5183 Raises: 5184 RuntimeError: If dataset_dir does not contain data files. 5185 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 5186 RuntimeError: If sampler and shuffle are specified at the same time. 5187 RuntimeError: If sampler and sharding are specified at the same time. 5188 RuntimeError: If num_shards is specified but shard_id is None. 5189 RuntimeError: If shard_id is specified but num_shards is None. 5190 RuntimeError: If parse JSON file failed. 5191 ValueError: If task is not in [`Detection`, `Stuff`, `Panoptic`, `Keypoint`]. 5192 ValueError: If annotation_file is not exist. 5193 ValueError: If dataset_dir is not exist. 5194 ValueError: If shard_id is invalid (< 0 or >= num_shards). 5195 5196 Note: 5197 - Column '[_meta-filename, dtype=string]' won't be output unless an explicit rename dataset op is added 5198 to remove the prefix('_meta-'). 5199 - CocoDataset doesn't support PKSampler. 5200 - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 5201 The table below shows what input arguments are allowed and their expected behavior. 5202 5203 .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` 5204 :widths: 25 25 50 5205 :header-rows: 1 5206 5207 * - Parameter `sampler` 5208 - Parameter `shuffle` 5209 - Expected Order Behavior 5210 * - None 5211 - None 5212 - random order 5213 * - None 5214 - True 5215 - random order 5216 * - None 5217 - False 5218 - sequential order 5219 * - Sampler object 5220 - None 5221 - order defined by sampler 5222 * - Sampler object 5223 - True 5224 - not allowed 5225 * - Sampler object 5226 - False 5227 - not allowed 5228 5229 Examples: 5230 >>> coco_dataset_dir = "/path/to/coco_dataset_directory/images" 5231 >>> coco_annotation_file = "/path/to/coco_dataset_directory/annotation_file" 5232 >>> 5233 >>> # 1) Read COCO data for Detection task 5234 >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir, 5235 ... annotation_file=coco_annotation_file, 5236 ... task='Detection') 5237 >>> 5238 >>> # 2) Read COCO data for Stuff task 5239 >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir, 5240 ... annotation_file=coco_annotation_file, 5241 ... task='Stuff') 5242 >>> 5243 >>> # 3) Read COCO data for Panoptic task 5244 >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir, 5245 ... annotation_file=coco_annotation_file, 5246 ... task='Panoptic') 5247 >>> 5248 >>> # 4) Read COCO data for Keypoint task 5249 >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir, 5250 ... annotation_file=coco_annotation_file, 5251 ... task='Keypoint') 5252 >>> 5253 >>> # In COCO dataset, each dictionary has keys "image" and "annotation" 5254 5255 About COCO dataset: 5256 5257 COCO(Microsoft Common Objects in Context) is a large-scale object detection, segmentation, and captioning dataset 5258 with several features: Object segmentation, Recognition in context, Superpixel stuff segmentation, 5259 330K images (>200K labeled), 1.5 million object instances, 80 object categories, 91 stuff categories, 5260 5 captions per image, 250,000 people with keypoints. In contrast to the popular ImageNet dataset, COCO has fewer 5261 categories but more instances in per category. 5262 5263 You can unzip the original COCO-2017 dataset files into this directory structure and read by MindSpore's API. 5264 5265 .. code-block:: 5266 5267 . 5268 └── coco_dataset_directory 5269 ├── train2017 5270 │ ├── 000000000009.jpg 5271 │ ├── 000000000025.jpg 5272 │ ├── ... 5273 ├── test2017 5274 │ ├── 000000000001.jpg 5275 │ ├── 000000058136.jpg 5276 │ ├── ... 5277 ├── val2017 5278 │ ├── 000000000139.jpg 5279 │ ├── 000000057027.jpg 5280 │ ├── ... 5281 └── annotations 5282 ├── captions_train2017.json 5283 ├── captions_val2017.json 5284 ├── instances_train2017.json 5285 ├── instances_val2017.json 5286 ├── person_keypoints_train2017.json 5287 └── person_keypoints_val2017.json 5288 5289 Citation: 5290 5291 .. code-block:: 5292 5293 @article{DBLP:journals/corr/LinMBHPRDZ14, 5294 author = {Tsung{-}Yi Lin and Michael Maire and Serge J. Belongie and 5295 Lubomir D. Bourdev and Ross B. Girshick and James Hays and 5296 Pietro Perona and Deva Ramanan and Piotr Doll{\'{a}}r and C. Lawrence Zitnick}, 5297 title = {Microsoft {COCO:} Common Objects in Context}, 5298 journal = {CoRR}, 5299 volume = {abs/1405.0312}, 5300 year = {2014}, 5301 url = {http://arxiv.org/abs/1405.0312}, 5302 archivePrefix = {arXiv}, 5303 eprint = {1405.0312}, 5304 timestamp = {Mon, 13 Aug 2018 16:48:13 +0200}, 5305 biburl = {https://dblp.org/rec/journals/corr/LinMBHPRDZ14.bib}, 5306 bibsource = {dblp computer science bibliography, https://dblp.org} 5307 } 5308 """ 5309 5310 @check_cocodataset 5311 def __init__(self, dataset_dir, annotation_file, task="Detection", num_samples=None, num_parallel_workers=None, 5312 shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None, 5313 extra_metadata=False): 5314 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 5315 shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) 5316 self.dataset_dir = dataset_dir 5317 self.annotation_file = annotation_file 5318 self.task = replace_none(task, "Detection") 5319 self.decode = replace_none(decode, False) 5320 self.extra_metadata = extra_metadata 5321 5322 def parse(self, children=None): 5323 return cde.CocoNode(self.dataset_dir, self.annotation_file, self.task, self.decode, self.sampler, 5324 self.extra_metadata) 5325 5326 def get_class_indexing(self): 5327 """ 5328 Get the class index. 5329 5330 Returns: 5331 dict, a str-to-list<int> mapping from label name to index. 5332 5333 Examples: 5334 >>> coco_dataset_dir = "/path/to/coco_dataset_directory/images" 5335 >>> coco_annotation_file = "/path/to/coco_dataset_directory/annotation_file" 5336 >>> 5337 >>> # Read COCO data for Detection task 5338 >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir, 5339 ... annotation_file=coco_annotation_file, 5340 ... task='Detection') 5341 >>> 5342 >>> class_indexing = dataset.get_class_indexing() 5343 """ 5344 if self.task not in {"Detection", "Panoptic"}: 5345 raise NotImplementedError("Only 'Detection' and 'Panoptic' support get_class_indexing.") 5346 if self._class_indexing is None: 5347 runtime_getter = self._init_tree_getters() 5348 self._class_indexing = dict(runtime_getter[0].GetClassIndexing()) 5349 return self._class_indexing 5350 5351 5352class CelebADataset(MappableDataset): 5353 """ 5354 A source dataset for reading and parsing CelebA dataset. 5355 Only support to read `list_attr_celeba.txt` currently, which is the attribute annotations of the dataset. 5356 5357 The generated dataset has two columns: :py:obj:`[image, attr]`. 5358 The tensor of column :py:obj:`image` is of the uint8 type. 5359 The tensor of column :py:obj:`attr` is of the uint32 type and one hot encoded. 5360 5361 Args: 5362 dataset_dir (str): Path to the root directory that contains the dataset. 5363 num_parallel_workers (int, optional): Number of workers to read the data (default=None, will use value set in 5364 the config). 5365 shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None). 5366 usage (str, optional): Specify the `train`, `valid`, `test` part or `all` parts of dataset 5367 (default=`all`, will read all samples). 5368 sampler (Sampler, optional): Object used to choose samples from the dataset (default=None). 5369 decode (bool, optional): decode the images after reading (default=False). 5370 extensions (list[str], optional): List of file extensions to be included in the dataset (default=None). 5371 num_samples (int, optional): The number of images to be included in the dataset 5372 (default=None, will include all images). 5373 num_shards (int, optional): Number of shards that the dataset will be divided 5374 into (default=None). When this argument is specified, `num_samples` reflects 5375 the maximum sample number of per shard. 5376 shard_id (int, optional): The shard ID within `num_shards` (default=None). This 5377 argument can only be specified when `num_shards` is also specified. 5378 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 5379 (default=None, which means no cache is used). 5380 5381 Raises: 5382 RuntimeError: If dataset_dir does not contain data files. 5383 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 5384 RuntimeError: If sampler and shuffle are specified at the same time. 5385 RuntimeError: If sampler and sharding are specified at the same time. 5386 RuntimeError: If num_shards is specified but shard_id is None. 5387 RuntimeError: If shard_id is specified but num_shards is None. 5388 ValueError: If shard_id is invalid (< 0 or >= num_shards). 5389 5390 Note: 5391 - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 5392 The table below shows what input arguments are allowed and their expected behavior. 5393 5394 .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` 5395 :widths: 25 25 50 5396 :header-rows: 1 5397 5398 * - Parameter `sampler` 5399 - Parameter `shuffle` 5400 - Expected Order Behavior 5401 * - None 5402 - None 5403 - random order 5404 * - None 5405 - True 5406 - random order 5407 * - None 5408 - False 5409 - sequential order 5410 * - Sampler object 5411 - None 5412 - order defined by sampler 5413 * - Sampler object 5414 - True 5415 - not allowed 5416 * - Sampler object 5417 - False 5418 - not allowed 5419 5420 Examples: 5421 >>> celeba_dataset_dir = "/path/to/celeba_dataset_directory" 5422 >>> 5423 >>> # Read 5 samples from CelebA dataset 5424 >>> dataset = ds.CelebADataset(dataset_dir=celeba_dataset_dir, usage='train', num_samples=5) 5425 >>> 5426 >>> # Note: In celeba dataset, each data dictionary owns keys "image" and "attr" 5427 5428 About CelebA dataset: 5429 5430 CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset 5431 with more than 200K celebrity images, each with 40 attribute annotations. 5432 5433 The images in this dataset cover large pose variations and background clutter. 5434 CelebA has large diversities, large quantities, and rich annotations, including 5435 5436 * 10,177 number of identities, 5437 * 202,599 number of face images, and 5438 * 5 landmark locations, 40 binary attributes annotations per image. 5439 5440 The dataset can be employed as the training and test sets for the following computer 5441 vision tasks: face attribute recognition, face detection, landmark (or facial part) 5442 localization, and face editing & synthesis. 5443 5444 Original CelebA dataset structure: 5445 5446 .. code-block:: 5447 5448 . 5449 └── CelebA 5450 ├── README.md 5451 ├── Img 5452 │ ├── img_celeba.7z 5453 │ ├── img_align_celeba_png.7z 5454 │ └── img_align_celeba.zip 5455 ├── Eval 5456 │ └── list_eval_partition.txt 5457 └── Anno 5458 ├── list_landmarks_celeba.txt 5459 ├── list_landmarks_align_celeba.txt 5460 ├── list_bbox_celeba.txt 5461 ├── list_attr_celeba.txt 5462 └── identity_CelebA.txt 5463 5464 You can unzip the dataset files into the following structure and read by MindSpore's API. 5465 5466 .. code-block:: 5467 5468 . 5469 └── celeba_dataset_directory 5470 ├── list_attr_celeba.txt 5471 ├── 000001.jpg 5472 ├── 000002.jpg 5473 ├── 000003.jpg 5474 ├── ... 5475 5476 Citation: 5477 5478 .. code-block:: 5479 5480 @article{DBLP:journals/corr/LiuLWT14, 5481 author = {Ziwei Liu and Ping Luo and Xiaogang Wang and Xiaoou Tang}, 5482 title = {Deep Learning Face Attributes in the Wild}, 5483 journal = {CoRR}, 5484 volume = {abs/1411.7766}, 5485 year = {2014}, 5486 url = {http://arxiv.org/abs/1411.7766}, 5487 archivePrefix = {arXiv}, 5488 eprint = {1411.7766}, 5489 timestamp = {Tue, 10 Dec 2019 15:37:26 +0100}, 5490 biburl = {https://dblp.org/rec/journals/corr/LiuLWT14.bib}, 5491 bibsource = {dblp computer science bibliography, https://dblp.org}, 5492 howpublished = {http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html} 5493 } 5494 """ 5495 5496 @check_celebadataset 5497 def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False, 5498 extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None): 5499 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 5500 shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) 5501 self.dataset_dir = dataset_dir 5502 self.decode = replace_none(decode, False) 5503 self.extensions = replace_none(extensions, []) 5504 self.usage = replace_none(usage, "all") 5505 5506 def parse(self, children=None): 5507 if self.usage != "all": 5508 dataset_dir = os.path.realpath(self.dataset_dir) 5509 partition_file = os.path.join(dataset_dir, "list_eval_partition.txt") 5510 if os.path.exists(partition_file) is False: 5511 raise RuntimeError("Partition file can not be found when usage is not 'all'.") 5512 return cde.CelebANode(self.dataset_dir, self.usage, self.sampler, self.decode, self.extensions) 5513 5514 5515class CLUEDataset(SourceDataset): 5516 """ 5517 A source dataset that reads and parses CLUE datasets. 5518 Supported CLUE classification tasks: `AFQMC`, `TNEWS`, `IFLYTEK`, `CMNLI`, `WSC` and `CSL`. 5519 5520 The generated dataset with different task setting has different output columns: 5521 5522 - task = :py:obj:`AFQMC` 5523 - usage = :py:obj:`train`, output columns: :py:obj:`[sentence1, dtype=string]`, \ 5524 :py:obj:`[sentence2, dtype=string]`, :py:obj:`[label, dtype=string]`. 5525 - usage = :py:obj:`test`, output columns: :py:obj:`[id, dtype=uint8]`, \ 5526 :py:obj:`[sentence1, dtype=string]`, :py:obj:`[sentence2, dtype=string]`. 5527 - usage = :py:obj:`eval`, output columns: :py:obj:`[sentence1, dtype=string]`, \ 5528 :py:obj:`[sentence2, dtype=string]`, :py:obj:`[label, dtype=string]`. 5529 5530 - task = :py:obj:`TNEWS` 5531 - usage = :py:obj:`train`, output columns: :py:obj:`[label, dtype=string]`, \ 5532 :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`, :py:obj:`[keywords, dtype=string]`. 5533 - usage = :py:obj:`test`, output columns: :py:obj:`[label, dtype=string]`, \ 5534 :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`, :py:obj:`[keywords, dtype=string]`. 5535 - usage = :py:obj:`eval`, output columns: :py:obj:`[label, dtype=string]`, \ 5536 :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`, :py:obj:`[keywords, dtype=string]`. 5537 5538 - task = :py:obj:`IFLYTEK` 5539 - usage = :py:obj:`train`, output columns: :py:obj:`[label, dtype=string]`, \ 5540 :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`. 5541 - usage = :py:obj:`test`, output columns: :py:obj:`[id, dtype=string]`, \ 5542 :py:obj:`[sentence, dtype=string]`. 5543 - usage = :py:obj:`eval`, output columns: :py:obj:`[label, dtype=string]`, \ 5544 :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`. 5545 5546 - task = :py:obj:`CMNLI` 5547 - usage = :py:obj:`train`, output columns: :py:obj:`[sentence1, dtype=string]`, \ 5548 :py:obj:`[sentence2, dtype=string]`, :py:obj:`[label, dtype=string]`. 5549 - usage = :py:obj:`test`, output columns: :py:obj:`[id, dtype=uint8]`, \ 5550 :py:obj:`[sentence1, dtype=string]`, :py:obj:`[sentence2, dtype=string]`. 5551 - usage = :py:obj:`eval`, output columns: :py:obj:`[sentence1, dtype=string]`, \ 5552 :py:obj:`[sentence2, dtype=string]`, :py:obj:`[label, dtype=string]`. 5553 5554 - task = :py:obj:`WSC` 5555 - usage = :py:obj:`train`, output columns: :py:obj:`[span1_index, dtype=uint8]`, \ 5556 :py:obj:`[span2_index, dtype=uint8]`, :py:obj:`[span1_text, dtype=string]`, \ 5557 :py:obj:`[span2_text, dtype=string]`, :py:obj:`[idx, dtype=uint8]`, \ 5558 :py:obj:`[text, dtype=string]`, :py:obj:`[label, dtype=string]`. 5559 - usage = output columns: :py:obj:`[span1_index, dtype=uint8]`, \ 5560 :py:obj:`[span2_index, dtype=uint8]`, :py:obj:`[span1_text, dtype=string]`, \ 5561 :py:obj:`[span2_text, dtype=string]`, :py:obj:`[idx, dtype=uint8]`, :py:obj:`[text, dtype=string]`. 5562 - usage = :py:obj:`eval`, output columns: :py:obj:`[span1_index, dtype=uint8]`, \ 5563 :py:obj:`[span2_index, dtype=uint8]`, :py:obj:`[span1_text, dtype=string]`, \ 5564 :py:obj:`[span2_text, dtype=string]`, :py:obj:`[idx, dtype=uint8]`, \ 5565 :py:obj:`[text, dtype=string]`, :py:obj:`[label, dtype=string]`. 5566 5567 - task = :py:obj:`CSL` 5568 - usage = :py:obj:`train`, output columns: :py:obj:`[id, dtype=uint8]`, \ 5569 :py:obj:`[abst, dtype=string]`, :py:obj:`[keyword, dtype=string]`, :py:obj:`[label, dtype=string]`. 5570 - usage = :py:obj:`test`, output columns: :py:obj:`[id, dtype=uint8]`, \ 5571 :py:obj:`[abst, dtype=string]`, :py:obj:`[keyword, dtype=string]`. 5572 - usage = :py:obj:`eval`, output columns: :py:obj:`[id, dtype=uint8]`, \ 5573 :py:obj:`[abst, dtype=string]`, :py:obj:`[keyword, dtype=string]`, :py:obj:`[label, dtype=string]`. 5574 5575 Args: 5576 dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for 5577 a pattern of files. The list will be sorted in a lexicographical order. 5578 task (str, optional): The kind of task, one of `AFQMC`, `TNEWS`, `IFLYTEK`, `CMNLI`, `WSC` and `CSL`. 5579 (default=AFQMC). 5580 usage (str, optional): Specify the `train`, `test` or `eval` part of dataset (default="train"). 5581 num_samples (int, optional): The number of samples to be included in the dataset 5582 (default=None, will include all images). 5583 num_parallel_workers (int, optional): Number of workers to read the data 5584 (default=None, number set in the config). 5585 shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch 5586 (default=Shuffle.GLOBAL). 5587 If shuffle is False, no shuffling will be performed; 5588 If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL 5589 Otherwise, there are two levels of shuffling: 5590 5591 - Shuffle.GLOBAL: Shuffle both the files and samples. 5592 5593 - Shuffle.FILES: Shuffle files only. 5594 5595 num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). 5596 When this argument is specified, `num_samples` reflects the maximum sample number of per shard. 5597 shard_id (int, optional): The shard ID within num_shards (default=None). This 5598 argument can only be specified when num_shards is also specified. 5599 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 5600 (default=None, which means no cache is used). 5601 5602 Raises: 5603 RuntimeError: If dataset_files are not valid or do not exist. 5604 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 5605 RuntimeError: If num_shards is specified but shard_id is None. 5606 RuntimeError: If shard_id is specified but num_shards is None. 5607 5608 Examples: 5609 >>> clue_dataset_dir = ["/path/to/clue_dataset_file"] # contains 1 or multiple clue files 5610 >>> dataset = ds.CLUEDataset(dataset_files=clue_dataset_dir, task='AFQMC', usage='train') 5611 5612 About CLUE dataset: 5613 5614 CLUE, a Chinese Language Understanding Evaluation benchmark. It contains multiple 5615 tasks, including single-sentence classification, sentence pair classification, and machine 5616 reading comprehension. 5617 5618 You can unzip the dataset files into the following structure and read by MindSpore's API, 5619 such as afqmc dataset: 5620 5621 .. code-block:: 5622 5623 . 5624 └── afqmc_public 5625 ├── train.json 5626 ├── test.json 5627 └── dev.json 5628 5629 Citation: 5630 5631 .. code-block:: 5632 5633 @article{CLUEbenchmark, 5634 title = {CLUE: A Chinese Language Understanding Evaluation Benchmark}, 5635 author = {Liang Xu, Xuanwei Zhang, Lu Li, Hai Hu, Chenjie Cao, Weitang Liu, Junyi Li, Yudong Li, 5636 Kai Sun, Yechen Xu, Yiming Cui, Cong Yu, Qianqian Dong, Yin Tian, Dian Yu, Bo Shi, Jun Zeng, 5637 Rongzhao Wang, Weijian Xie, Yanting Li, Yina Patterson, Zuoyu Tian, Yiwen Zhang, He Zhou, 5638 Shaoweihua Liu, Qipeng Zhao, Cong Yue, Xinrui Zhang, Zhengliang Yang, Zhenzhong Lan}, 5639 journal = {arXiv preprint arXiv:2004.05986}, 5640 year = {2020}, 5641 howpublished = {https://github.com/CLUEbenchmark/CLUE} 5642 } 5643 """ 5644 5645 @check_cluedataset 5646 def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None, num_parallel_workers=None, 5647 shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): 5648 super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, 5649 num_shards=num_shards, shard_id=shard_id, cache=cache) 5650 self.dataset_files = self._find_files(dataset_files) 5651 self.usage = replace_none(usage, 'train') 5652 self.task = replace_none(task, 'AFQMC') 5653 5654 def parse(self, children=None): 5655 return cde.CLUENode(self.dataset_files, self.task, self.usage, self.num_samples, self.shuffle_flag, 5656 self.num_shards, self.shard_id) 5657 5658 5659class CSVDataset(SourceDataset): 5660 """ 5661 A source dataset that reads and parses comma-separated values (CSV) datasets. 5662 The columns of generated dataset depend on the source CSV files. 5663 5664 Args: 5665 dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search 5666 for a pattern of files. The list will be sorted in a lexicographical order. 5667 field_delim (str, optional): A string that indicates the char delimiter to separate fields (default=','). 5668 column_defaults (list, optional): List of default values for the CSV field (default=None). Each item 5669 in the list is either a valid type (float, int, or string). If this is not provided, treats all 5670 columns as string type. 5671 column_names (list[str], optional): List of column names of the dataset (default=None). If this 5672 is not provided, infers the column_names from the first row of CSV file. 5673 num_samples (int, optional): The number of samples to be included in the dataset 5674 (default=None, will include all images). 5675 num_parallel_workers (int, optional): Number of workers to read the data 5676 (default=None, number set in the config). 5677 shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch 5678 (default=Shuffle.GLOBAL). 5679 If shuffle is False, no shuffling will be performed; 5680 If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL 5681 Otherwise, there are two levels of shuffling: 5682 5683 - Shuffle.GLOBAL: Shuffle both the files and samples. 5684 5685 - Shuffle.FILES: Shuffle files only. 5686 5687 num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). 5688 When this argument is specified, `num_samples` reflects the maximum sample number of per shard. 5689 shard_id (int, optional): The shard ID within num_shards (default=None). This 5690 argument can only be specified when num_shards is also specified. 5691 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 5692 (default=None, which means no cache is used). 5693 5694 Raises: 5695 RuntimeError: If dataset_files are not valid or do not exist. 5696 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 5697 RuntimeError: If num_shards is specified but shard_id is None. 5698 RuntimeError: If shard_id is specified but num_shards is None. 5699 5700 Examples: 5701 >>> csv_dataset_dir = ["/path/to/csv_dataset_file"] # contains 1 or multiple csv files 5702 >>> dataset = ds.CSVDataset(dataset_files=csv_dataset_dir, column_names=['col1', 'col2', 'col3', 'col4']) 5703 """ 5704 5705 @check_csvdataset 5706 def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None, 5707 num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): 5708 super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, 5709 num_shards=num_shards, shard_id=shard_id, cache=cache) 5710 self.dataset_files = self._find_files(dataset_files) 5711 self.dataset_files.sort() 5712 self.field_delim = replace_none(field_delim, ',') 5713 self.column_defaults = replace_none(column_defaults, []) 5714 self.column_names = replace_none(column_names, []) 5715 5716 def parse(self, children=None): 5717 return cde.CSVNode(self.dataset_files, self.field_delim, self.column_defaults, self.column_names, 5718 self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id) 5719 5720 5721class SBUDataset(MappableDataset): 5722 """ 5723 A source dataset for reading and parsing the SBU dataset. 5724 5725 The generated dataset has two columns :py:obj:`[image, caption]`. 5726 The tensor of column :py:obj:`image` is of the uint8 type. 5727 The tensor of column :py:obj:`caption` is of the string type. 5728 5729 Args: 5730 dataset_dir (str): Path to the root directory that contains the dataset. 5731 decode (bool, optional): Decode the images after reading (default=False). 5732 num_samples (int, optional): The number of images to be included in the dataset 5733 (default=None, will read all images). 5734 num_parallel_workers (int, optional): Number of workers to read the data 5735 (default=None, will use value set in the config). 5736 shuffle (bool, optional): Whether or not to perform shuffle on the dataset 5737 (default=None, expected order behavior shown in the table). 5738 sampler (Sampler, optional): Object used to choose samples from the 5739 dataset (default=None, expected order behavior shown in the table). 5740 num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). 5741 When this argument is specified, `num_samples` reflects the max sample number of per shard. 5742 shard_id (int, optional): The shard ID within `num_shards` (default=None). This 5743 argument can only be specified when `num_shards` is also specified. 5744 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 5745 (default=None, which means no cache is used). 5746 5747 Raises: 5748 RuntimeError: If dataset_dir does not contain data files. 5749 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 5750 RuntimeError: If sampler and shuffle are specified at the same time. 5751 RuntimeError: If sampler and sharding are specified at the same time. 5752 RuntimeError: If num_shards is specified but shard_id is None. 5753 RuntimeError: If shard_id is specified but num_shards is None. 5754 ValueError: If shard_id is invalid (< 0 or >= num_shards). 5755 5756 Note: 5757 - This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. 5758 The table below shows what input arguments are allowed and their expected behavior. 5759 5760 .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle' 5761 :widths: 25 25 50 5762 :header-rows: 1 5763 5764 * - Parameter 'sampler' 5765 - Parameter 'shuffle' 5766 - Expected Order Behavior 5767 * - None 5768 - None 5769 - random order 5770 * - None 5771 - True 5772 - random order 5773 * - None 5774 - False 5775 - sequential order 5776 * - Sampler object 5777 - None 5778 - order defined by sampler 5779 * - Sampler object 5780 - True 5781 - not allowed 5782 * - Sampler object 5783 - False 5784 - not allowed 5785 5786 Examples: 5787 >>> sbu_dataset_dir = "/path/to/sbu_dataset_directory" 5788 >>> # Read 3 samples from SBU dataset 5789 >>> dataset = ds.SBUDataset(dataset_dir=sbu_dataset_dir, num_samples=3) 5790 5791 About SBU dataset: 5792 5793 SBU dataset is a large captioned photo collection. 5794 It contains one million images with associated visually relevant captions. 5795 5796 You should manually download the images using official download.m by replacing 'urls{i}(24, end)' with 5797 'urls{i}(24:1:end)' and keep the directory as below. 5798 5799 .. code-block:: 5800 5801 . 5802 └─ dataset_dir 5803 ├── SBU_captioned_photo_dataset_captions.txt 5804 ├── SBU_captioned_photo_dataset_urls.txt 5805 └── sbu_images 5806 ├── m_3326_3596303505_3ce4c20529.jpg 5807 ├── ...... 5808 └── m_2522_4182181099_c3c23ab1cc.jpg 5809 5810 Citation: 5811 5812 .. code-block:: 5813 5814 @inproceedings{Ordonez:2011:im2text, 5815 Author = {Vicente Ordonez and Girish Kulkarni and Tamara L. Berg}, 5816 Title = {Im2Text: Describing Images Using 1 Million Captioned Photographs}, 5817 Booktitle = {Neural Information Processing Systems ({NIPS})}, 5818 Year = {2011}, 5819 } 5820 """ 5821 5822 @check_sbu_dataset 5823 def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, 5824 sampler=None, num_shards=None, shard_id=None, cache=None): 5825 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 5826 shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) 5827 5828 self.dataset_dir = dataset_dir 5829 self.decode = replace_none(decode, False) 5830 5831 def parse(self, children=None): 5832 return cde.SBUNode(self.dataset_dir, self.decode, self.sampler) 5833 5834 5835class _Flowers102Dataset: 5836 """ 5837 Mainly for loading Flowers102 Dataset, and return one row each time. 5838 """ 5839 def __init__(self, dataset_dir, task, usage, decode): 5840 self.dataset_dir = os.path.realpath(dataset_dir) 5841 self.task = task 5842 self.usage = usage 5843 self.decode = decode 5844 5845 if self.task == "Classification": 5846 self.column_names = ["image", "label"] 5847 else: 5848 self.column_names = ["image", "segmentation", "label"] 5849 5850 labels_path = os.path.join(self.dataset_dir, "imagelabels.mat") 5851 setid_path = os.path.join(self.dataset_dir, "setid.mat") 5852 # minus one to transform 1~102 to 0 ~ 101 5853 self.labels = (loadmat(labels_path)["labels"][0] - 1).astype(np.uint32) 5854 self.setid = loadmat(setid_path) 5855 5856 if self.usage == 'train': 5857 self.indices = self.setid["trnid"][0].tolist() 5858 elif self.usage == 'test': 5859 self.indices = self.setid["tstid"][0].tolist() 5860 elif self.usage == 'valid': 5861 self.indices = self.setid["valid"][0].tolist() 5862 elif self.usage == 'all': 5863 self.indices = self.setid["trnid"][0].tolist() 5864 self.indices += self.setid["tstid"][0].tolist() 5865 self.indices += self.setid["valid"][0].tolist() 5866 else: 5867 raise ValueError("Input usage is not within the valid set of ['train', 'valid', 'test', 'all'].") 5868 5869 def __getitem__(self, index): 5870 # range: 1 ~ 8189 5871 image_path = os.path.join(self.dataset_dir, "jpg", "image_" + str(self.indices[index]).zfill(5) + ".jpg") 5872 if not os.path.exists(image_path): 5873 raise RuntimeError("Can not find image file: " + image_path) 5874 5875 if self.decode is True: 5876 image = np.asarray(Image.open(image_path).convert("RGB")) 5877 else: 5878 image = np.fromfile(image_path, dtype=np.uint8) 5879 5880 label = self.labels[self.indices[index] - 1] 5881 5882 if self.task == "Segmentation": 5883 segmentation_path = \ 5884 os.path.join(self.dataset_dir, "segmim", "segmim_" + str(self.indices[index]).zfill(5) + ".jpg") 5885 if not os.path.exists(segmentation_path): 5886 raise RuntimeError("Can not find segmentation file: " + segmentation_path) 5887 if self.decode is True: 5888 segmentation = np.asarray(Image.open(segmentation_path).convert("RGB")) 5889 else: 5890 segmentation = np.fromfile(segmentation_path, dtype=np.uint8) 5891 return image, segmentation, label 5892 5893 return image, label 5894 5895 def __len__(self): 5896 return len(self.indices) 5897 5898 5899class Flowers102Dataset(GeneratorDataset): 5900 """ 5901 A source dataset for reading and parsing Flowers102 dataset. 5902 5903 The generated dataset has two columns :py:obj:`[image, label]` or three :py:obj:`[image, segmentation, label]`. 5904 The tensor of column :py:obj:`image` is of the uint8 type. 5905 The tensor of column :py:obj:`segmentation` is of the uint8 type. 5906 The tensor of column :py:obj:`label` is a scalar or a tensor of the uint32 type. 5907 5908 Args: 5909 dataset_dir (str): Path to the root directory that contains the dataset. 5910 task (str): Specify the 'Classification' or 'Segmentation' task (default='Classification'). 5911 usage (str): Specify the 'train', 'valid', 'test' part or 'all' parts of dataset 5912 (default='all', will read all samples). 5913 num_samples (int, optional): The number of samples to be included in the dataset (default=None, all images). 5914 num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1). 5915 shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required. 5916 (default=None, expected order behavior shown in the table). 5917 decode (bool, optional): Whether or not to decode the images and segmentations after reading (default=False). 5918 sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. Random accessible 5919 input is required (default=None, expected order behavior shown in the table). 5920 num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). 5921 Random accessible input is required. When this argument is specified, 'num_samples' reflects the max 5922 sample number of per shard. 5923 shard_id (int, optional): The shard ID within num_shards (default=None). This argument must be specified only 5924 when num_shards is also specified. Random accessible input is required. 5925 5926 Raises: 5927 RuntimeError: If dataset_dir does not contain data files. 5928 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 5929 RuntimeError: If sampler and shuffle are specified at the same time. 5930 RuntimeError: If sampler and sharding are specified at the same time. 5931 RuntimeError: If num_shards is specified but shard_id is None. 5932 RuntimeError: If shard_id is specified but num_shards is None. 5933 ValueError: If shard_id is invalid (< 0 or >= num_shards). 5934 5935 Note: 5936 - This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. 5937 The table below shows what input arguments are allowed and their expected behavior. 5938 5939 .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle' 5940 :widths: 25 25 50 5941 :header-rows: 1 5942 5943 * - Parameter 'sampler' 5944 - Parameter 'shuffle' 5945 - Expected Order Behavior 5946 * - None 5947 - None 5948 - random order 5949 * - None 5950 - True 5951 - random order 5952 * - None 5953 - False 5954 - sequential order 5955 * - Sampler object 5956 - None 5957 - order defined by sampler 5958 * - Sampler object 5959 - True 5960 - not allowed 5961 * - Sampler object 5962 - False 5963 - not allowed 5964 5965 Examples: 5966 >>> flowers102_dataset_dir = "/path/to/flowers102_dataset_directory" 5967 >>> dataset = ds.Flowers102Dataset(dataset_dir=flowers102_dataset_dir, 5968 ... task="Classification", 5969 ... usage="all", 5970 ... decode=True) 5971 5972 About Flowers102 dataset: 5973 5974 Flowers102 dataset consists of 102 flower categories. 5975 The flowers commonly occur in the United Kingdom. 5976 Each class consists of between 40 and 258 images. 5977 5978 Here is the original Flowers102 dataset structure. 5979 You can unzip the dataset files into this directory structure and read by MindSpore's API. 5980 5981 .. code-block:: 5982 . 5983 └── flowes102_dataset_dir 5984 ├── imagelabels.mat 5985 ├── setid.mat 5986 ├── jpg 5987 ├── image_00001.jpg 5988 ├── image_00002.jpg 5989 ├── ... 5990 ├── segmim 5991 ├── segmim_00001.jpg 5992 ├── segmim_00002.jpg 5993 ├── ... 5994 5995 Citation: 5996 5997 .. code-block:: 5998 5999 @InProceedings{Nilsback08, 6000 author = "Maria-Elena Nilsback and Andrew Zisserman", 6001 title = "Automated Flower Classification over a Large Number of Classes", 6002 booktitle = "Indian Conference on Computer Vision, Graphics and Image Processing", 6003 month = "Dec", 6004 year = "2008", 6005 } 6006 """ 6007 6008 @check_flowers102dataset 6009 def __init__(self, dataset_dir, task="Classification", usage="all", num_samples=None, num_parallel_workers=1, 6010 shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None): 6011 self.dataset_dir = os.path.realpath(dataset_dir) 6012 self.task = replace_none(task, "Classification") 6013 self.usage = replace_none(usage, "all") 6014 self.decode = replace_none(decode, False) 6015 dataset = _Flowers102Dataset(self.dataset_dir, self.task, self.usage, self.decode) 6016 super().__init__(dataset, column_names=dataset.column_names, num_samples=num_samples, 6017 num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, 6018 num_shards=num_shards, shard_id=shard_id) 6019 6020 def get_class_indexing(self): 6021 """ 6022 Get the class index. 6023 6024 Returns: 6025 dict, a str-to-int mapping from label name to index. 6026 """ 6027 class_names = [ 6028 "pink primrose", "hard-leaved pocket orchid", "canterbury bells", 6029 "sweet pea", "english marigold", "tiger lily", "moon orchid", 6030 "bird of paradise", "monkshood", "globe thistle", "snapdragon", 6031 "colt's foot", "king protea", "spear thistle", "yellow iris", 6032 "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", 6033 "giant white arum lily", "fire lily", "pincushion flower", "fritillary", 6034 "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", 6035 "stemless gentian", "artichoke", "sweet william", "carnation", 6036 "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", 6037 "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", 6038 "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", 6039 "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", 6040 "common dandelion", "petunia", "wild pansy", "primula", "sunflower", 6041 "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", 6042 "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", 6043 "black-eyed susan", "silverbush", "californian poppy", "osteospermum", 6044 "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", 6045 "azalea", "water lily", "rose", "thorn apple", "morning glory", 6046 "passion flower", "lotus", "toad lily", "anthurium", "frangipani", 6047 "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", 6048 "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", 6049 "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", 6050 "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", 6051 "blackberry lily" 6052 ] 6053 6054 class_dict = {} 6055 for i, class_name in enumerate(class_names): 6056 class_dict[class_name] = i 6057 6058 return class_dict 6059 6060 6061class TextFileDataset(SourceDataset): 6062 """ 6063 A source dataset that reads and parses datasets stored on disk in text format. 6064 The generated dataset has one column :py:obj:`[text]` with type string. 6065 6066 Args: 6067 dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a 6068 pattern of files. The list will be sorted in a lexicographical order. 6069 num_samples (int, optional): The number of samples to be included in the dataset 6070 (default=None, will include all images). 6071 num_parallel_workers (int, optional): Number of workers to read the data 6072 (default=None, number set in the config). 6073 shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch 6074 (default=Shuffle.GLOBAL). 6075 If shuffle is False, no shuffling will be performed; 6076 If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL 6077 Otherwise, there are two levels of shuffling: 6078 6079 - Shuffle.GLOBAL: Shuffle both the files and samples. 6080 6081 - Shuffle.FILES: Shuffle files only. 6082 6083 num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). 6084 When this argument is specified, `num_samples` reflects the maximum sample number of per shard. 6085 shard_id (int, optional): The shard ID within num_shards (default=None). This 6086 argument can only be specified when num_shards is also specified. 6087 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 6088 (default=None, which means no cache is used). 6089 6090 Raises: 6091 RuntimeError: If dataset_files are not valid or do not exist. 6092 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 6093 RuntimeError: If num_shards is specified but shard_id is None. 6094 RuntimeError: If shard_id is specified but num_shards is None. 6095 6096 Examples: 6097 >>> text_file_dataset_dir = ["/path/to/text_file_dataset_file"] # contains 1 or multiple text files 6098 >>> dataset = ds.TextFileDataset(dataset_files=text_file_dataset_dir) 6099 """ 6100 6101 @check_textfiledataset 6102 def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, 6103 num_shards=None, shard_id=None, cache=None): 6104 super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, 6105 num_shards=num_shards, shard_id=shard_id, cache=cache) 6106 self.dataset_files = self._find_files(dataset_files) 6107 self.dataset_files.sort() 6108 6109 def parse(self, children=None): 6110 return cde.TextFileNode(self.dataset_files, self.num_samples, self.shuffle_flag, self.num_shards, 6111 self.shard_id) 6112 6113 6114class _NumpySlicesDataset: 6115 """ 6116 Mainly for dealing with several kinds of formats of Python data, and return one row each time. 6117 """ 6118 6119 def __init__(self, data, column_list=None): 6120 self.column_list = None 6121 # Convert dict data into tuple 6122 if isinstance(data, dict): 6123 data = self.process_dict(data) 6124 6125 if isinstance(data, tuple): 6126 self.data = () 6127 data_len = len(data) 6128 for i in range(data_len): 6129 self.data = self.data + (np.array(data[i]),) 6130 else: 6131 self.data = (np.array(data),) 6132 6133 # check whether the data length in each column is equal 6134 data_len = [len(data_item) for data_item in self.data] 6135 if data_len[1:] != data_len[:-1]: 6136 raise ValueError("Data length in each column is not equal.") 6137 6138 # Init column_name 6139 if column_list is not None: 6140 self.column_list = column_list 6141 elif self.column_list is None: 6142 self.column_list = [] 6143 column_num = len(self.data) 6144 for i in range(column_num): 6145 self.column_list.append("column_" + str(i)) 6146 6147 def __getitem__(self, index): 6148 data_row = [d[index, ...] for d in self.data] 6149 data_res = tuple(data_row) 6150 return data_res 6151 6152 def __len__(self): 6153 return len(self.data[0]) 6154 6155 def process_dict(self, input_data): 6156 """ 6157 Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first. 6158 """ 6159 # Convert pandas like dict(has "values" column) into General dict 6160 data_keys = list(input_data.keys()) 6161 data_col = input_data[data_keys[0]] 6162 if hasattr(data_col, "values"): 6163 new_dict = {} 6164 for key in data_keys: 6165 item1 = input_data.pop(key) 6166 new_dict[key] = item1.values 6167 input_data = new_dict 6168 6169 # Convert the data in dict into tuple 6170 data = () 6171 keys = list(input_data.keys()) 6172 self.column_list = keys 6173 for key in keys: 6174 value = input_data[key] 6175 data = data + (list(value),) 6176 6177 return data 6178 6179 6180class NumpySlicesDataset(GeneratorDataset): 6181 """ 6182 Creates a dataset with given data slices, mainly for loading Python data into dataset. 6183 6184 The column names and column types of generated dataset depend on Python data defined by users. 6185 6186 Args: 6187 data (Union[list, tuple, dict]) Input of given data. Supported data types include: list, tuple, dict and other 6188 NumPy formats. Input data will be sliced along the first dimension and generate additional rows, if input is 6189 list, there will be one column in each row, otherwise there tends to be multi columns. Large data is not 6190 recommended to be loaded in this way as data is loading into memory. 6191 column_names (list[str], optional): List of column names of the dataset (default=None). If column_names is not 6192 provided, the output column names will be named as the keys of dict when the input data is a dict, 6193 otherwise they will be named like column_0, column_1 ... 6194 num_samples (int, optional): The number of samples to be included in the dataset (default=None, all samples). 6195 num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1). 6196 shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required. 6197 (default=None, expected order behavior shown in the table). 6198 sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. Random accessible 6199 input is required (default=None, expected order behavior shown in the table). 6200 num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). 6201 Random accessible input is required. When this argument is specified, `num_samples` reflects the max 6202 sample number of per shard. 6203 shard_id (int, optional): The shard ID within num_shards (default=None). This argument must be specified only 6204 when num_shards is also specified. Random accessible input is required. 6205 6206 Note: 6207 - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 6208 The table below shows what input arguments are allowed and their expected behavior. 6209 6210 .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` 6211 :widths: 25 25 50 6212 :header-rows: 1 6213 6214 * - Parameter `sampler` 6215 - Parameter `shuffle` 6216 - Expected Order Behavior 6217 * - None 6218 - None 6219 - random order 6220 * - None 6221 - True 6222 - random order 6223 * - None 6224 - False 6225 - sequential order 6226 * - Sampler object 6227 - None 6228 - order defined by sampler 6229 * - Sampler object 6230 - True 6231 - not allowed 6232 * - Sampler object 6233 - False 6234 - not allowed 6235 6236 Raises: 6237 RuntimeError: If len of column_names does not match output len of data. 6238 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 6239 RuntimeError: If sampler and shuffle are specified at the same time. 6240 RuntimeError: If sampler and sharding are specified at the same time. 6241 RuntimeError: If num_shards is specified but shard_id is None. 6242 RuntimeError: If shard_id is specified but num_shards is None. 6243 ValueError: If shard_id is invalid (< 0 or >= num_shards). 6244 6245 Examples: 6246 >>> # 1) Input data can be a list 6247 >>> data = [1, 2, 3] 6248 >>> dataset = ds.NumpySlicesDataset(data=data, column_names=["column_1"]) 6249 >>> 6250 >>> # 2) Input data can be a dictionary, and column_names will be its keys 6251 >>> data = {"a": [1, 2], "b": [3, 4]} 6252 >>> dataset = ds.NumpySlicesDataset(data=data) 6253 >>> 6254 >>> # 3) Input data can be a tuple of lists (or NumPy arrays), each tuple element refers to data in each column 6255 >>> data = ([1, 2], [3, 4], [5, 6]) 6256 >>> dataset = ds.NumpySlicesDataset(data=data, column_names=["column_1", "column_2", "column_3"]) 6257 >>> 6258 >>> # 4) Load data from CSV file 6259 >>> import pandas as pd 6260 >>> df = pd.read_csv(filepath_or_buffer=csv_dataset_dir[0]) 6261 >>> dataset = ds.NumpySlicesDataset(data=dict(df), shuffle=False) 6262 """ 6263 6264 @check_numpyslicesdataset 6265 def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, 6266 num_shards=None, shard_id=None): 6267 dataset = _NumpySlicesDataset(data, column_names) 6268 super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples, 6269 num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, 6270 num_shards=num_shards, shard_id=shard_id) 6271 6272 6273class _PaddedDataset: 6274 """ 6275 Mainly for combining false samples provided by users into a dataset. 6276 6277 Args: 6278 padded_samples (list(dict)): Data provided by user to be added to the initial Dataset. 6279 """ 6280 6281 def __init__(self, padded_samples): 6282 self.column_names = list(padded_samples[0].keys()) 6283 self.padded_samples = padded_samples 6284 6285 def __getitem__(self, item): 6286 return (self.padded_samples[item][key] for key in self.column_names) 6287 6288 def __len__(self): 6289 return len(self.padded_samples) 6290 6291 6292class PaddedDataset(GeneratorDataset): 6293 """ 6294 Creates a dataset with filler data provided by user. Mainly used to add to the original data set 6295 and assign it to the corresponding shard. 6296 6297 Args: 6298 padded_samples (list(dict)): Samples provided by user. 6299 6300 Raises: 6301 TypeError: If padded_samples is not an instance of list. 6302 TypeError: If the element of padded_samples is not an instance of dict. 6303 ValueError: If the padded_samples is empty. 6304 6305 Examples: 6306 >>> import numpy as np 6307 >>> data = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)}] 6308 >>> dataset = ds.PaddedDataset(padded_samples=data) 6309 """ 6310 6311 @check_paddeddataset 6312 def __init__(self, padded_samples): 6313 dataset = _PaddedDataset(padded_samples) 6314 super().__init__(dataset, column_names=dataset.column_names, num_shards=None, shard_id=None, shuffle=False) 6315 self._dataset_size = len(dataset.padded_samples) 6316 self.padded_samples = padded_samples 6317 6318 6319class FlickrDataset(MappableDataset): 6320 """ 6321 A source dataset for reading and parsing Flickr8k and Flickr30k dataset. 6322 6323 The generated dataset has two columns :py:obj:`[image, annotation]`. 6324 The tensor of column :py:obj:`image` is of the uint8 type. 6325 The tensor of column :py:obj:`annotation` is a tensor which contains 5 annotations string, 6326 such as ["a", "b", "c", "d", "e"]. 6327 6328 Args: 6329 dataset_dir (str): Path to the root directory that contains the dataset. 6330 annotation_file (str): Path to the root directory that contains the annotation. 6331 num_samples (int, optional): The number of images to be included in the dataset. 6332 (default=None, all images). 6333 num_parallel_workers (int, optional): Number of workers to read the data 6334 (default=None, number set in the config). 6335 shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected 6336 order behavior shown in the table). 6337 decode (bool, optional): Decode the images after reading (default=False). 6338 sampler (Sampler, optional): Object used to choose samples from the 6339 dataset (default=None, expected order behavior shown in the table). 6340 num_shards (int, optional): Number of shards that the dataset will be divided 6341 into (default=None). When this argument is specified, `num_samples` reflects 6342 the max sample number of per shard. 6343 shard_id (int, optional): The shard ID within num_shards (default=None). This 6344 argument can only be specified when num_shards is also specified. 6345 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 6346 (default=None, which means no cache is used). 6347 6348 Raises: 6349 RuntimeError: If dataset_dir is not valid or does not contain data files. 6350 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 6351 RuntimeError: If sampler and shuffle are specified at the same time. 6352 RuntimeError: If sampler and sharding are specified at the same time. 6353 RuntimeError: If num_shards is specified but shard_id is None. 6354 RuntimeError: If shard_id is specified but num_shards is None. 6355 ValueError: If dataset_dir is not exist. 6356 ValueError: If annotation_file is not exist. 6357 ValueError: If shard_id is invalid (< 0 or >= num_shards). 6358 6359 Note: 6360 - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 6361 The table below shows what input arguments are allowed and their expected behavior. 6362 6363 .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` 6364 :widths: 25 25 50 6365 :header-rows: 1 6366 6367 * - Parameter `sampler` 6368 - Parameter `shuffle` 6369 - Expected Order Behavior 6370 * - None 6371 - None 6372 - random order 6373 * - None 6374 - True 6375 - random order 6376 * - None 6377 - False 6378 - sequential order 6379 * - Sampler object 6380 - None 6381 - order defined by sampler 6382 * - Sampler object 6383 - True 6384 - not allowed 6385 * - Sampler object 6386 - False 6387 - not allowed 6388 6389 Examples: 6390 >>> flickr_dataset_dir = "/path/to/flickr_dataset_directory" 6391 >>> annotation_file = "/path/to/flickr_annotation_file" 6392 >>> 6393 >>> # 1) Get all samples from FLICKR dataset in sequence 6394 >>> dataset = ds.FlickrDataset(dataset_dir=flickr_dataset_dir, 6395 ... annotation_file=annotation_file, 6396 ... shuffle=False) 6397 >>> 6398 >>> # 2) Randomly select 350 samples from FLICKR dataset 6399 >>> dataset = ds.FlickrDataset(dataset_dir=flickr_dataset_dir, 6400 ... annotation_file=annotation_file, 6401 ... num_samples=350, 6402 ... shuffle=True) 6403 >>> 6404 >>> # 3) Get samples from FLICKR dataset for shard 0 in a 2-way distributed training 6405 >>> dataset = ds.FlickrDataset(dataset_dir=flickr_dataset_dir, 6406 ... annotation_file=annotation_file, 6407 ... num_shards=2, 6408 ... shard_id=0) 6409 >>> 6410 >>> # In FLICKR dataset, each dictionary has keys "image" and "annotation" 6411 6412 About Flickr8k dataset: 6413 6414 The Flickr8k dataset consists of 8092 colour images. There are 40460 annotations in the Flickr8k.token.txt, 6415 each image has 5 annotations. 6416 6417 You can unzip the dataset files into the following directory structure and read by MindSpore's API. 6418 6419 .. code-block:: 6420 6421 . 6422 └── Flickr8k 6423 ├── Flickr8k_Dataset 6424 │ ├── 1000268201_693b08cb0e.jpg 6425 │ ├── 1001773457_577c3a7d70.jpg 6426 │ ├── ... 6427 └── Flickr8k.token.txt 6428 6429 Citation: 6430 6431 .. code-block:: 6432 6433 @article{DBLP:journals/jair/HodoshYH13, 6434 author = {Micah Hodosh and Peter Young and Julia Hockenmaier}, 6435 title = {Framing Image Description as a Ranking Task: Data, Models and Evaluation Metrics}, 6436 journal = {J. Artif. Intell. Res.}, 6437 volume = {47}, 6438 pages = {853--899}, 6439 year = {2013}, 6440 url = {https://doi.org/10.1613/jair.3994}, 6441 doi = {10.1613/jair.3994}, 6442 timestamp = {Mon, 21 Jan 2019 15:01:17 +0100}, 6443 biburl = {https://dblp.org/rec/journals/jair/HodoshYH13.bib}, 6444 bibsource = {dblp computer science bibliography, https://dblp.org} 6445 } 6446 6447 About Flickr30k dataset: 6448 6449 The Flickr30k dataset consists of 31783 colour images. There are 158915 annotations in 6450 the results_20130124.token, each image has 5 annotations. 6451 6452 You can unzip the dataset files into the following directory structure and read by MindSpore's API. 6453 6454 Citation: 6455 6456 .. code-block:: 6457 6458 . 6459 └── Flickr30k 6460 ├── flickr30k-images 6461 │ ├── 1000092795.jpg 6462 │ ├── 10002456.jpg 6463 │ ├── ... 6464 └── results_20130124.token 6465 6466 .. code-block:: 6467 6468 @article{DBLP:journals/tacl/YoungLHH14, 6469 author = {Peter Young and Alice Lai and Micah Hodosh and Julia Hockenmaier}, 6470 title = {From image descriptions to visual denotations: New similarity metrics 6471 for semantic inference over event descriptions}, 6472 journal = {Trans. Assoc. Comput. Linguistics}, 6473 volume = {2}, 6474 pages = {67--78}, 6475 year = {2014}, 6476 url = {https://tacl2013.cs.columbia.edu/ojs/index.php/tacl/article/view/229}, 6477 timestamp = {Wed, 17 Feb 2021 21:55:25 +0100}, 6478 biburl = {https://dblp.org/rec/journals/tacl/YoungLHH14.bib}, 6479 bibsource = {dblp computer science bibliography, https://dblp.org} 6480 } 6481 """ 6482 6483 @check_flickr_dataset 6484 def __init__(self, dataset_dir, annotation_file, num_samples=None, num_parallel_workers=None, shuffle=None, 6485 decode=None, sampler=None, num_shards=None, shard_id=None, cache=None): 6486 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 6487 shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) 6488 6489 self.dataset_dir = dataset_dir 6490 self.annotation_file = annotation_file 6491 self.decode = replace_none(decode, False) 6492 6493 def parse(self, children=None): 6494 return cde.FlickrNode(self.dataset_dir, self.annotation_file, self.decode, self.sampler) 6495 6496 6497class SBDataset(GeneratorDataset): 6498 """ 6499 A source dataset for reading and parsing Semantic Boundaries Dataset. 6500 6501 The generated dataset has two columns: :py:obj:`[image, task]`. 6502 The tensor of column :py:obj:`image` is of the uint8 type. 6503 The tensor of column :py:obj:`task` contains 20 images of the uint8 type if `task` is `Boundaries` otherwise 6504 contains 1 image of the uint8 type. 6505 6506 Args: 6507 dataset_dir (str): Path to the root directory that contains the dataset. 6508 task (str, optional): Acceptable tasks include `Boundaries` or `Segmentation` (default=`Boundaries`). 6509 usage (str, optional): Acceptable usages include `train`, `val`, `train_noval` and `all` (default=`all`). 6510 num_samples (int, optional): The number of images to be included in the dataset. 6511 (default=None, all images). 6512 num_parallel_workers (int, optional): Number of workers to read the data 6513 (default=None, number set in the config). 6514 shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected 6515 order behavior shown in the table). 6516 sampler (Sampler, optional): Object used to choose samples from the 6517 dataset (default=None, expected order behavior shown in the table). 6518 num_shards (int, optional): Number of shards that the dataset will be divided 6519 into (default=None). When this argument is specified, `num_samples` reflects 6520 the max sample number of per shard. 6521 shard_id (int, optional): The shard ID within num_shards (default=None). This 6522 argument can only be specified when num_shards is also specified. 6523 6524 Raises: 6525 RuntimeError: If dataset_dir is not valid or does not contain data files. 6526 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 6527 RuntimeError: If sampler and shuffle are specified at the same time. 6528 RuntimeError: If sampler and sharding are specified at the same time. 6529 RuntimeError: If num_shards is specified but shard_id is None. 6530 RuntimeError: If shard_id is specified but num_shards is None. 6531 ValueError: If dataset_dir is not exist. 6532 ValueError: If task is not in [`Boundaries`, `Segmentation`]. 6533 ValueError: If usage is not in [`train`, `val`, `train_noval`, `all`]. 6534 ValueError: If shard_id is invalid (< 0 or >= num_shards). 6535 6536 Note: 6537 - This dataset can take in a sampler. `sampler` and `shuffle` are mutually exclusive. 6538 The table below shows what input arguments are allowed and their expected behavior. 6539 6540 .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` 6541 :widths: 25 25 50 6542 :header-rows: 1 6543 6544 * - Parameter `sampler` 6545 - Parameter `shuffle` 6546 - Expected Order Behavior 6547 * - None 6548 - None 6549 - random order 6550 * - None 6551 - True 6552 - random order 6553 * - None 6554 - False 6555 - sequential order 6556 * - Sampler object 6557 - None 6558 - order defined by sampler 6559 * - Sampler object 6560 - True 6561 - not allowed 6562 * - Sampler object 6563 - False 6564 - not allowed 6565 6566 Examples: 6567 >>> sb_dataset_dir = "/path/to/sb_dataset_directory" 6568 >>> 6569 >>> # 1) Get all samples from Semantic Boundaries Dataset in sequence 6570 >>> dataset = ds.SBDataset(dataset_dir=sb_dataset_dir, shuffle=False) 6571 >>> 6572 >>> # 2) Randomly select 350 samples from Semantic Boundaries Dataset 6573 >>> dataset = ds.SBDataset(dataset_dir=sb_dataset_dir, num_samples=350, shuffle=True) 6574 >>> 6575 >>> # 3) Get samples from Semantic Boundaries Dataset for shard 0 in a 2-way distributed training 6576 >>> dataset = ds.SBDataset(dataset_dir=sb_dataset_dir, num_shards=2, shard_id=0) 6577 >>> 6578 >>> # In Semantic Boundaries Dataset, each dictionary has keys "image" and "task" 6579 6580 About Semantic Boundaries Dataset: 6581 6582 The Semantic Boundaries Dataset consists of 11355 colour images. There are 8498 images' name in the train.txt, 6583 2857 images' name in the val.txt and 5623 images' name in the train_noval.txt. The category cls/ 6584 contains the Segmentation and Boundaries results of category-level, the category inst/ catains the 6585 Segmentation and Boundaries results of instance-level. 6586 6587 You can unzip the dataset files into the following structure and read by MindSpore's API: 6588 6589 .. code-block:: 6590 6591 . 6592 └── benchmark_RELEASE 6593 ├── dataset 6594 ├── img 6595 │ ├── 2008_000002.jpg 6596 │ ├── 2008_000003.jpg 6597 │ ├── ... 6598 ├── cls 6599 │ ├── 2008_000002.mat 6600 │ ├── 2008_000003.mat 6601 │ ├── ... 6602 ├── inst 6603 │ ├── 2008_000002.mat 6604 │ ├── 2008_000003.mat 6605 │ ├── ... 6606 ├── train.txt 6607 └── val.txt 6608 6609 .. code-block:: 6610 6611 @InProceedings{BharathICCV2011, 6612 author = "Bharath Hariharan and Pablo Arbelaez and Lubomir Bourdev and 6613 Subhransu Maji and Jitendra Malik", 6614 title = "Semantic Contours from Inverse Detectors", 6615 booktitle = "International Conference on Computer Vision (ICCV)", 6616 year = "2011", 6617 """ 6618 6619 @check_sb_dataset 6620 def __init__(self, dataset_dir, task='Boundaries', usage='all', num_samples=None, num_parallel_workers=1, 6621 shuffle=None, decode=None, sampler=None, num_shards=None, shard_id=None): 6622 dataset = _SBDataset(dataset_dir, task, usage, decode) 6623 super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples, 6624 num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, 6625 num_shards=num_shards, shard_id=shard_id) 6626 6627 6628class _SBDataset: 6629 """ 6630 Dealing with the data file with .mat extension, and return one row in tuple (image, task) each time. 6631 """ 6632 6633 def __init__(self, dataset_dir, task, usage, decode): 6634 self.column_list = ['image', 'task'] 6635 self.task = task 6636 self.images_path = os.path.join(dataset_dir, 'img') 6637 self.cls_path = os.path.join(dataset_dir, 'cls') 6638 self._loadmat = loadmat 6639 self.categories = 20 6640 self.decode = replace_none(decode, False) 6641 6642 if usage == "all": 6643 image_names = [] 6644 for item in ["train", "val"]: 6645 usage_path = os.path.join(dataset_dir, item + '.txt') 6646 if not os.path.exists(usage_path): 6647 raise FileNotFoundError("SBDataset: {0} not found".format(usage_path)) 6648 with open(usage_path, 'r') as f: 6649 image_names += [x.strip() for x in f.readlines()] 6650 else: 6651 usage_path = os.path.join(dataset_dir, usage + '.txt') 6652 if not os.path.exists(usage_path): 6653 raise FileNotFoundError("SBDataset: {0} not found".format(usage_path)) 6654 with open(usage_path, 'r') as f: 6655 image_names = [x.strip() for x in f.readlines()] 6656 6657 self.images = [os.path.join(self.images_path, i + ".jpg") for i in image_names] 6658 self.clss = [os.path.join(self.cls_path, i + ".mat") for i in image_names] 6659 6660 if len(self.images) != len(self.clss): 6661 raise ValueError("SBDataset: images count not equal to cls count") 6662 6663 self._get_data = self._get_boundaries_data if self.task == "Boundaries" else self._get_segmentation_data 6664 self._get_item = self._get_decode_item if self.decode else self._get_undecode_item 6665 6666 def _get_boundaries_data(self, mat_path): 6667 mat_data = self._loadmat(mat_path) 6668 return np.concatenate([np.expand_dims(mat_data['GTcls'][0][self.task][0][i][0].toarray(), axis=0) 6669 for i in range(self.categories)], axis=0) 6670 6671 def _get_segmentation_data(self, mat_path): 6672 mat_data = self._loadmat(mat_path) 6673 return Image.fromarray(mat_data['GTcls'][0][self.task][0]) 6674 6675 def _get_decode_item(self, idx): 6676 return Image.open(self.images[idx]).convert('RGB'), self._get_data(self.clss[idx]) 6677 6678 def _get_undecode_item(self, idx): 6679 return np.fromfile(self.images[idx], dtype=np.uint8), self._get_data(self.clss[idx]) 6680 6681 def __len__(self): 6682 return len(self.images) 6683 6684 def __getitem__(self, idx): 6685 return self._get_item(idx) 6686 6687 6688class DeserializedDataset(Dataset): 6689 def __init__(self, input_obj): 6690 super().__init__() 6691 self.input_obj = input_obj 6692 6693 def parse(self, children=None): 6694 if isinstance(self.input_obj, dict): 6695 json_str = json.dumps(self.input_obj) 6696 return cde.Dataset.from_json_string(json_str) 6697 return cde.Dataset.from_json_file(self.input_obj) 6698 6699 6700class CityscapesDataset(MappableDataset): 6701 """ 6702 A source dataset for reading and parsing Cityscapes dataset. 6703 6704 The generated dataset has two columns :py:obj:`[image, task]`. 6705 The tensor of column :py:obj:`image` is of the uint8 type. 6706 The tensor of column :py:obj:`task` is of the uint8 type if task is not 'polygon' otherwise task is 6707 a string tensor with serialize json. 6708 6709 Args: 6710 dataset_dir (str): Path to the root directory that contains the dataset. 6711 usage (str): Acceptable usages include `train`, `test`, `val` or `all` if quality_mode is `fine` 6712 otherwise `train`, `train_extra`, `val` or `all` (default=`train`). 6713 quality_mode (str): Acceptable quality_modes include `fine` or `coarse` (default=`fine`). 6714 task (str): Acceptable tasks include `instance`, `semantic`, `polygon` or `color` (default=`instance`). 6715 num_samples (int, optional): The number of images to be included in the dataset. 6716 (default=None, all images). 6717 num_parallel_workers (int, optional): Number of workers to read the data 6718 (default=None, number set in the config). 6719 shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected 6720 order behavior shown in the table). 6721 decode (bool, optional): Decode the images after reading (default=False). 6722 sampler (Sampler, optional): Object used to choose samples from the 6723 dataset (default=None, expected order behavior shown in the table). 6724 num_shards (int, optional): Number of shards that the dataset will be divided 6725 into (default=None). When this argument is specified, `num_samples` reflects 6726 the max sample number of per shard. 6727 shard_id (int, optional): The shard ID within num_shards (default=None). This 6728 argument can only be specified when num_shards is also specified. 6729 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 6730 (default=None, which means no cache is used). 6731 6732 Raises: 6733 RuntimeError: If dataset_dir is invalid or does not contain data files. 6734 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 6735 RuntimeError: If sampler and shuffle are specified at the same time. 6736 RuntimeError: If sampler and sharding are specified at the same time. 6737 RuntimeError: If num_shards is specified but shard_id is None. 6738 RuntimeError: If shard_id is specified but num_shards is None. 6739 ValueError: If dataset_dir is not exist. 6740 ValueError: If task is invalid. 6741 ValueError: If quality_mode is invalid. 6742 ValueError: If usage is invalid. 6743 ValueError: If shard_id is invalid (< 0 or >= num_shards). 6744 6745 Note: 6746 - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 6747 The table below shows what input arguments are allowed and their expected behavior. 6748 6749 .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` 6750 :widths: 25 25 50 6751 :header-rows: 1 6752 6753 * - Parameter `sampler` 6754 - Parameter `shuffle` 6755 - Expected Order Behavior 6756 * - None 6757 - None 6758 - random order 6759 * - None 6760 - True 6761 - random order 6762 * - None 6763 - False 6764 - sequential order 6765 * - Sampler object 6766 - None 6767 - order defined by sampler 6768 * - Sampler object 6769 - True 6770 - not allowed 6771 * - Sampler object 6772 - False 6773 - not allowed 6774 6775 Examples: 6776 >>> cityscapes_dataset_dir = "/path/to/cityscapes_dataset_directory" 6777 >>> 6778 >>> # 1) Get all samples from Cityscapes dataset in sequence 6779 >>> dataset = ds.CityscapesDataset(dataset_dir=cityscapes_dataset_dir, task="instance", quality_mode="fine", 6780 >>> usage="train", shuffle=False, num_parallel_workers=1) 6781 >>> 6782 >>> # 2) Randomly select 350 samples from Cityscapes dataset 6783 >>> dataset = ds.CityscapesDataset(dataset_dir=cityscapes_dataset_dir, num_samples=350, shuffle=True, 6784 >>> num_parallel_workers=1) 6785 >>> 6786 >>> # 3) Get samples from Cityscapes dataset for shard 0 in a 2-way distributed training 6787 >>> dataset = ds.CityscapesDataset(dataset_dir=cityscapes_dataset_dir, num_shards=2, shard_id=0, 6788 >>> num_parallel_workers=1) 6789 >>> 6790 >>> # In Cityscapes dataset, each dictionary has keys "image" and "task" 6791 6792 About Cityscapes dataset: 6793 6794 The Cityscapes dataset consists of 5000 colour images with high quality dense pixel annotations and 6795 19998 colour images with coarser polygonal annotations in 50 cities. There are 30 classes in this 6796 dataset and the polygonal annotations include dense semantic segmentation and instance segmentation 6797 for vehicle and people. 6798 6799 You can unzip the dataset files into the following directory structure and read by MindSpore's API. 6800 6801 Taking the quality_mode of `fine` as an example. 6802 6803 .. code-block:: 6804 6805 . 6806 └── Cityscapes 6807 ├── leftImg8bit 6808 | ├── train 6809 | | ├── aachen 6810 | | | ├── aachen_000000_000019_leftImg8bit.png 6811 | | | ├── aachen_000001_000019_leftImg8bit.png 6812 | | | ├── ... 6813 | | ├── bochum 6814 | | | ├── ... 6815 | | ├── ... 6816 | ├── test 6817 | | ├── ... 6818 | ├── val 6819 | | ├── ... 6820 └── gtFine 6821 ├── train 6822 | ├── aachen 6823 | | ├── aachen_000000_000019_gtFine_color.png 6824 | | ├── aachen_000000_000019_gtFine_instanceIds.png 6825 | | ├── aachen_000000_000019_gtFine_labelIds.png 6826 | | ├── aachen_000000_000019_gtFine_polygons.json 6827 | | ├── aachen_000001_000019_gtFine_color.png 6828 | | ├── aachen_000001_000019_gtFine_instanceIds.png 6829 | | ├── aachen_000001_000019_gtFine_labelIds.png 6830 | | ├── aachen_000001_000019_gtFine_polygons.json 6831 | | ├── ... 6832 | ├── bochum 6833 | | ├── ... 6834 | ├── ... 6835 ├── test 6836 | ├── ... 6837 └── val 6838 ├── ... 6839 6840 Citation: 6841 6842 .. code-block:: 6843 6844 @inproceedings{Cordts2016Cityscapes, 6845 title = {The Cityscapes Dataset for Semantic Urban Scene Understanding}, 6846 author = {Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, 6847 Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt}, 6848 booktitle = {Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 6849 year = {2016} 6850 } 6851 """ 6852 6853 @check_cityscapes_dataset 6854 def __init__(self, dataset_dir, usage="train", quality_mode="fine", task="instance", num_samples=None, 6855 num_parallel_workers=None, shuffle=None, decode=None, sampler=None, num_shards=None, 6856 shard_id=None, cache=None): 6857 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 6858 shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) 6859 6860 self.dataset_dir = dataset_dir 6861 self.task = task 6862 self.quality_mode = quality_mode 6863 self.usage = usage 6864 self.decode = replace_none(decode, False) 6865 6866 def parse(self, children=None): 6867 return cde.CityscapesNode(self.dataset_dir, self.usage, self.quality_mode, self.task, self.decode, self.sampler) 6868 6869 6870class DIV2KDataset(MappableDataset): 6871 """ 6872 A source dataset for reading and parsing DIV2KDataset dataset. 6873 6874 The generated dataset has two columns :py:obj:`[hr_image, lr_image]`. 6875 The tensor of column :py:obj:`hr_image` is of the uint8 type. 6876 The tensor of column :py:obj:`lr_image` is of the uint8 type. 6877 6878 Args: 6879 dataset_dir (str): Path to the root directory that contains the dataset. 6880 usage (str): Acceptable usages include `train`, `valid` or `all` (default=`train`). 6881 downgrade (str): Acceptable downgrades include `bicubic`, `unknown`, `mild`, `difficult` or 6882 `wild` (default=`bicubic`). 6883 scale (int): Acceptable scales include 2, 3, 4 or 8 (default=2). 6884 When `downgrade` is `bicubic`, scale can be 2, 3, 4, 8. 6885 When `downgrade` is `unknown`, scale can only be 2, 3, 4. 6886 When `downgrade` is `mild`, `difficult` or `wild`, scale can only be 4. 6887 num_samples (int, optional): The number of images to be included in the dataset. 6888 (default=None, all images). 6889 num_parallel_workers (int, optional): Number of workers to read the data 6890 (default=None, number set in the config). 6891 shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected 6892 order behavior shown in the table). 6893 decode (bool, optional): Decode the images after reading (default=False). 6894 sampler (Sampler, optional): Object used to choose samples from the 6895 dataset (default=None, expected order behavior shown in the table). 6896 num_shards (int, optional): Number of shards that the dataset will be divided 6897 into (default=None). When this argument is specified, `num_samples` reflects 6898 the max sample number of per shard. 6899 shard_id (int, optional): The shard ID within num_shards (default=None). This 6900 argument can only be specified when num_shards is also specified. 6901 cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. 6902 (default=None, which means no cache is used). 6903 6904 Raises: 6905 RuntimeError: If dataset_dir is invalid or does not contain data files. 6906 RuntimeError: If num_parallel_workers exceeds the max thread numbers. 6907 RuntimeError: If sampler and shuffle are specified at the same time. 6908 RuntimeError: If sampler and sharding are specified at the same time. 6909 RuntimeError: If num_shards is specified but shard_id is None. 6910 RuntimeError: If shard_id is specified but num_shards is None. 6911 ValueError: If dataset_dir is not exist. 6912 ValueError: If usage is invalid. 6913 ValueError: If downgrade is invalid. 6914 ValueError: If scale is invalid. 6915 ValueError: If scale equal to 8 and downgrade not equal to `bicubic`. 6916 ValueError: If downgrade in [`mild`, `difficult`, `wild`] and scale not equal to 4. 6917 ValueError: If shard_id is invalid (< 0 or >= num_shards). 6918 6919 Note: 6920 - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. 6921 The table below shows what input arguments are allowed and their expected behavior. 6922 6923 .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` 6924 :widths: 25 25 50 6925 :header-rows: 1 6926 6927 * - Parameter `sampler` 6928 - Parameter `shuffle` 6929 - Expected Order Behavior 6930 * - None 6931 - None 6932 - random order 6933 * - None 6934 - True 6935 - random order 6936 * - None 6937 - False 6938 - sequential order 6939 * - Sampler object 6940 - None 6941 - order defined by sampler 6942 * - Sampler object 6943 - True 6944 - not allowed 6945 * - Sampler object 6946 - False 6947 - not allowed 6948 6949 Examples: 6950 >>> div2k_dataset_dir = "/path/to/div2k_dataset_directory" 6951 >>> 6952 >>> # 1) Get all samples from DIV2K dataset in sequence 6953 >>> dataset = ds.DIV2KDataset(dataset_dir=div2k_dataset_dir, usage="train", scale=2, downgrade="bicubic", 6954 >>> shuffle=False) 6955 >>> 6956 >>> # 2) Randomly select 350 samples from DIV2K dataset 6957 >>> dataset = ds.DIV2KDataset(dataset_dir=div2k_dataset_dir, usage="train", scale=2, downgrade="bicubic", 6958 >>> num_samples=350, shuffle=True) 6959 >>> 6960 >>> # 3) Get samples from DIV2K dataset for shard 0 in a 2-way distributed training 6961 >>> dataset = ds.DIV2KDataset(dataset_dir=div2k_dataset_dir, usage="train", scale=2, downgrade="bicubic", 6962 >>> num_shards=2, shard_id=0) 6963 >>> 6964 >>> # In DIV2K dataset, each dictionary has keys "hr_image" and "lr_image" 6965 6966 About DIV2K dataset: 6967 6968 The DIV2K dataset consists of 1000 2K resolution images, among which 800 images are for training, 100 images 6969 are for validation and 100 images are for testing. NTIRE 2017 and NTIRE 2018 include only training dataset 6970 and validation dataset. 6971 6972 You can unzip the dataset files into the following directory structure and read by MindSpore's API. 6973 6974 Take the training set as an example. 6975 6976 .. code-block:: 6977 6978 . 6979 └── DIV2K 6980 ├── DIV2K_train_HR 6981 | ├── 0001.png 6982 | ├── 0002.png 6983 | ├── ... 6984 ├── DIV2K_train_LR_bicubic 6985 | ├── X2 6986 | | ├── 0001x2.png 6987 | | ├── 0002x2.png 6988 | | ├── ... 6989 | ├── X3 6990 | | ├── 0001x3.png 6991 | | ├── 0002x3.png 6992 | | ├── ... 6993 | └── X4 6994 | ├── 0001x4.png 6995 | ├── 0002x4.png 6996 | ├── ... 6997 ├── DIV2K_train_LR_unknown 6998 | ├── X2 6999 | | ├── 0001x2.png 7000 | | ├── 0002x2.png 7001 | | ├── ... 7002 | ├── X3 7003 | | ├── 0001x3.png 7004 | | ├── 0002x3.png 7005 | | ├── ... 7006 | └── X4 7007 | ├── 0001x4.png 7008 | ├── 0002x4.png 7009 | ├── ... 7010 ├── DIV2K_train_LR_mild 7011 | ├── 0001x4m.png 7012 | ├── 0002x4m.png 7013 | ├── ... 7014 ├── DIV2K_train_LR_difficult 7015 | ├── 0001x4d.png 7016 | ├── 0002x4d.png 7017 | ├── ... 7018 ├── DIV2K_train_LR_wild 7019 | ├── 0001x4w.png 7020 | ├── 0002x4w.png 7021 | ├── ... 7022 └── DIV2K_train_LR_x8 7023 ├── 0001x8.png 7024 ├── 0002x8.png 7025 ├── ... 7026 Citation: 7027 7028 .. code-block:: 7029 7030 @InProceedings{Agustsson_2017_CVPR_Workshops, 7031 author = {Agustsson, Eirikur and Timofte, Radu}, 7032 title = {NTIRE 2017 Challenge on Single Image Super-Resolution: Dataset and Study}, 7033 booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 7034 url = "http://www.vision.ee.ethz.ch/~timofter/publications/Agustsson-CVPRW-2017.pdf", 7035 month = {July}, 7036 year = {2017} 7037 } 7038 """ 7039 7040 @check_div2k_dataset 7041 def __init__(self, dataset_dir, usage="train", downgrade="bicubic", scale=2, num_samples=None, 7042 num_parallel_workers=None, shuffle=None, decode=None, sampler=None, num_shards=None, 7043 shard_id=None, cache=None): 7044 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, 7045 shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) 7046 7047 self.dataset_dir = dataset_dir 7048 self.usage = usage 7049 self.scale = scale 7050 self.downgrade = downgrade 7051 self.decode = replace_none(decode, False) 7052 7053 def parse(self, children=None): 7054 return cde.DIV2KNode(self.dataset_dir, self.usage, self.downgrade, self.scale, self.decode, self.sampler) 7055