# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The sampler module provides several samplers to generate data from datasets.
The provided samplers include: DistributedSampler, PKSampler, RandomSampler,
SequentialSampler, SubsetRandomSampler, and WeightedRandomSampler.
Users can also define a custom sampler by extending from the Sampler class.
"""

import itertools
import numbers
from collections.abc import Iterator

import numpy as np
import mindspore._c_dataengine as cde
import mindspore.dataset as ds
from ..core import validator_helpers as validator


def _check_num_samples(num_samples):
    """
    Validate the `num_samples` argument shared by the builtin samplers.

    Args:
        num_samples (Union[int, None]): Requested number of samples. ``None`` is accepted and
            means the value was not specified.

    Raises:
        TypeError: If `num_samples` is neither ``None`` nor of type int.
        ValueError: If `num_samples` is negative or larger than INT64_MAX.
    """
    if num_samples is None:
        return
    if not isinstance(num_samples, int):
        raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
    if num_samples < 0 or num_samples > validator.INT64_MAX:
        raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                         .format(0, validator.INT64_MAX))


def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
    """
    Create sampler based on user input.

    Args:
        num_samples (int): Number of samples.
        input_sampler (Union[Iterable, Sampler]): Sampler from user.
        shuffle (bool): Shuffle.
        num_shards (int): Number of shard for sharding.
        shard_id (int): Shard ID.

    Returns:
        Sampler, sampler selected based on user input.

    Raises:
        ValueError: If a builtin sampler is given together with any of `num_samples`,
            `num_shards`, `shard_id` or `shuffle`.
        TypeError: If `input_sampler` is of an unsupported type.
    """

    if input_sampler is not None:
        # If the user provided a sampler, then it doesn't matter what the other args are because
        # we are being asked specifically to use the given sampler.
        # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
        # be None. Consider this example:
        #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
        #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
        # In this case, the user has given different sample-related arguments that contradict each other.
        # To prevent this, only allow the user to manually specify the sampler if those arguments are all None
        if isinstance(input_sampler, BuiltinSampler):
            if any(arg is not None for arg in (num_shards, shard_id, shuffle, num_samples)):
                raise ValueError(
                    'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
                    ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
            return input_sampler
        if isinstance(input_sampler, (np.ndarray, list, tuple)):
            return SubsetSampler(input_sampler, num_samples)
        if not isinstance(input_sampler, str) and validator.is_iterable(input_sampler):
            # in this case, the user passed in their own sampler object that's not of type BuiltinSampler
            return IterSampler(input_sampler, num_samples)
        if isinstance(input_sampler, int):
            return SubsetSampler([input_sampler])
        raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))

    if num_shards is not None:
        # Sharding enabled: use a distributed sampler. An unspecified shuffle defaults to True,
        # any explicit value is forwarded unchanged.
        return DistributedSampler(num_shards, shard_id,
                                  shuffle=True if shuffle is None else shuffle,
                                  num_samples=num_samples)

    if shuffle is None or shuffle is True:
        # Shuffle unspecified or enabled, sharding disabled: use a random sampler.
        # num_samples == 0 is handled like None here (the parse() methods below map None to 0),
        # so both cases must draw without replacement. Previously only the `shuffle is None`
        # branch honoured this.
        if num_samples is not None and num_samples != 0:
            return RandomSampler(replacement=True, num_samples=num_samples)
        return RandomSampler(num_samples=num_samples)

    # If shuffle disabled, sharding disabled, use sequential sampler
    return SequentialSampler(num_samples=num_samples)


class BuiltinSampler:
    """
    Base class for BuiltinSampler.

    User should not extend this class.
    """

    def __init__(self, num_samples=None):
        self.child_sampler = None
        self.num_samples = num_samples

    def parse(self):
        """ Parse the sampler."""

    def add_child(self, sampler):
        """
        Add a sub-sampler for given sampler. The parent will receive all data from the
        output of sub-sampler sampler and apply its sample logic to return new samples.

        Args:
            sampler (Sampler): Object used to choose samples from the dataset. Only builtin
                samplers(:class:`mindspore.dataset.DistributedSampler` ,
                :class:`mindspore.dataset.PKSampler`,
                :class:`mindspore.dataset.RandomSampler`,
                :class:`mindspore.dataset.SequentialSampler`,
                :class:`mindspore.dataset.SubsetRandomSampler`,
                :class:`mindspore.dataset.WeightedRandomSampler` ) are supported.

        Raises:
            RuntimeError: If this sampler already has a child sampler.

        Examples:
            >>> import mindspore.dataset as ds
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> sampler.add_child(ds.RandomSampler(num_samples=4))
            >>> dataset = ds.Cifar10Dataset(cifar10_dataset_dir, sampler=sampler)
        """
        if self.child_sampler is not None:
            raise RuntimeError("Cannot add child sampler, this sampler already has a child.")
        self.child_sampler = sampler

    def get_child(self):
        """
        Get the child sampler of given sampler.

        Returns:
            Sampler, The child sampler of given sampler.

        Examples:
            >>> import mindspore.dataset as ds
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> sampler.add_child(ds.RandomSampler(num_samples=2))
            >>> child_sampler = sampler.get_child()
        """
        return self.child_sampler

    def parse_child(self):
        """ Parse the child sampler. """
        c_child_sampler = None
        if self.child_sampler is not None:
            c_child_sampler = self.child_sampler.parse()
        return c_child_sampler

    def parse_child_for_minddataset(self):
        """ Parse the child sampler for MindRecord. """
        c_child_sampler = None
        if self.child_sampler is not None:
            c_child_sampler = self.child_sampler.parse_for_minddataset()
        return c_child_sampler

    def is_shuffled(self):
        """ Not implemented. """
        raise NotImplementedError("Sampler must implement is_shuffled.")

    def is_sharded(self):
        """ Not implemented. """
        raise NotImplementedError("Sampler must implement is_sharded.")

    def get_num_samples(self):
        """
        Get `num_samples` value of the current sampler instance.
        This parameter can be optionally passed in when defining the Sampler. Default: ``None``.
        This method will return the num_samples value.
        If the current sampler has child samplers,
        it will continue to access the child samplers and process the obtained value according to certain rules.

        The following table shows the various possible combinations, and the final results returned.

        .. list-table::
            :widths: 25 25 25 25
            :header-rows: 1

            * - child sampler
              - num_samples
              - child_samples
              - result
            * - T
              - x
              - y
              - min(x, y)
            * - T
              - x
              - None
              - x
            * - T
              - None
              - y
              - y
            * - T
              - None
              - None
              - None
            * - None
              - x
              - n/a
              - x
            * - None
              - None
              - n/a
              - None

        Returns:
            int, the number of samples, or None.

        Examples:
            >>> import mindspore.dataset as ds
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> num_samplers = sampler.get_num_samples()
        """
        if self.child_sampler is None:
            return self.num_samples

        child_samples = self.child_sampler.get_num_samples()
        if self.num_samples is None:
            return child_samples
        if child_samples is None:
            return self.num_samples
        return min(self.num_samples, child_samples)


class Sampler(BuiltinSampler):
    """
    Base class for user defined sampler.
    A user defined sampler can be used with any existing dataset with sampler support.

    A required __iter__() method should be overridden by the user for sample index generation.
    An optional reset() method can be overridden for per repeat reset.

    dataset_size and num_samples will be set by dataset once a dataset iterator is created.

    Examples:
        >>> import mindspore.dataset as ds
        >>> class ReverseSampler(ds.Sampler):
        ...     def __iter__(self):
        ...         for i in range(self.dataset_size - 1, -1, -1):
        ...             yield i
        >>>
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, sampler=ReverseSampler())
    """

    def __init__(self, num_samples=None):
        # The base class initializes child_sampler and num_samples.
        super().__init__(num_samples)
        self.dataset_size = 0

    def __iter__(self):
        """
        User defined iterator, must be overridden.
        _handshake is guaranteed to be called prior to iterator construction.
        """
        raise NotImplementedError

    def reset(self):
        """
        Per repeat reset callback, override this method if necessary
        """

    # Initialization handshake callback
    # Do not override this method!
    def _handshake(self, ds_size, num_samples):
        self.dataset_size = ds_size
        self.num_samples = num_samples

    # Indices fetcher
    # Do not override this method!
    # pylint: disable=missing-docstring
    def _get_indices(self):
        # Take at most num_samples indices from the user iterator. islice stops early when the
        # iterator is exhausted and, unlike range(), accepts num_samples=None (take everything).
        indices = np.array(list(itertools.islice(iter(self), self.num_samples)))
        if indices.dtype == object:
            raise RuntimeError("Fetched indices can not be converted to a valid ndarray.")
        return indices

    # Instance fetcher
    # Do not override this method!
    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.PreBuiltSamplerObj(num_samples, self)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def add_child(self, sampler):
        # Unlike BuiltinSampler.add_child, an existing child is replaced without raising.
        self.child_sampler = sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()

    def get_num_samples(self):
        if self.num_samples is None:
            return None
        # The real count is the number of indices the user iterator yields, capped at num_samples.
        return self._get_indices().size


class DistributedSampler(BuiltinSampler):
    """
    A sampler that accesses a shard of the dataset, it helps divide dataset into multi-subset for distributed training.

    Args:
        num_shards (int): Number of shards to divide the dataset into.
        shard_id (int): Shard ID of the current shard, which should within the range of [0, `num_shards` - 1].
        shuffle (bool, optional): If True, the indices are shuffled, otherwise it will not be shuffled.
            Default: ``True``.
        num_samples (int, optional): The number of samples to draw. Default: ``None``, which means sample all elements.
        offset(int, optional): The starting shard ID where the elements in the dataset are sent to, which
            should be no more than `num_shards` . This parameter is only valid when a ConcatDataset takes
            a :class:`mindspore.dataset.DistributedSampler` as its sampler. It will affect the number of
            samples of per shard. Default: ``-1``, which means each shard has the same number of samples.

    Raises:
        TypeError: If `num_shards` is not of type int.
        TypeError: If `shard_id` is not of type int.
        TypeError: If `shuffle` is not of type bool.
        TypeError: If `num_samples` is not of type int.
        TypeError: If `offset` is not of type int.
        ValueError: If `num_samples` is a negative value.
        RuntimeError: If `num_shards` is not a positive value.
        RuntimeError: If `shard_id` is smaller than 0 or equal to `num_shards` or larger than `num_shards` .
        RuntimeError: If `offset` is greater than `num_shards` .

    Examples:
        >>> import mindspore.dataset as ds
        >>> # creates a distributed sampler with 10 shards in total. This shard is shard 5.
        >>> sampler = ds.DistributedSampler(10, 5)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
        if not isinstance(num_shards, int):
            raise TypeError("num_shards must be integer but was: {}.".format(num_shards))

        if not isinstance(shard_id, int):
            raise TypeError("shard_id must be integer but was: {}.".format(shard_id))

        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle))

        _check_num_samples(num_samples)

        if not isinstance(offset, int):
            raise TypeError("offset must be integer but was: {}.".format(offset))

        self.num_shards = num_shards
        self.shard_id = shard_id
        self.shuffle = shuffle
        # get seed in distributed scenario
        # Example 1. if user set seeds by ds.config.set_seed(4321), then seed 4321 is used
        # Example 2. if user does not set the seed, then existing or default seed (like 5489) is used
        self.seed = ds.config.get_seed()
        self.offset = offset
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        shuffle = self.shuffle if self.shuffle is not None else True
        offset = self.offset if self.offset is not None else -1
        # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
        self.seed += 1
        c_sampler = cde.DistributedSamplerObj(self.num_shards, self.shard_id,
                                              shuffle, num_samples, self.seed, offset, True)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def parse_for_minddataset(self):
        """ Parse the sampler for MindRecord."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        shuffle = self.shuffle if self.shuffle is not None else True
        # set_offset() does not validate its argument, so normalize None exactly as parse() does
        offset = self.offset if self.offset is not None else -1
        c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, shuffle,
                                                     self.seed, num_samples, offset)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        c_sampler.set_num_samples(num_samples)
        return c_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return self.shuffle

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return self.num_shards > 1

        return self.child_sampler.is_sharded()

    def set_offset(self, offset):
        """ Set the offset and return this sampler so that calls can be chained. """
        self.offset = offset
        return self


class PKSampler(BuiltinSampler):
    """
    Samples K elements for each P class in the dataset.

    Args:
        num_val (int): Number of elements to sample for each class.
        num_class (int, optional): Number of classes to sample. Default: ``None`` , sample all classes.
            The parameter does not support to specify currently.
        shuffle (bool, optional): Whether to shuffle the class IDs. Default: ``False``.
        class_column (str, optional): Name of column with class labels for MindDataset. Default: ``'label'``.
        num_samples (int, optional): The number of samples to draw. Default: ``None`` , which means sample all elements.

    Raises:
        TypeError: If `num_val` is not of type int.
        TypeError: If `shuffle` is not of type bool.
        TypeError: If `class_column` is not of type str.
        TypeError: If `num_samples` is not of type int.
        NotImplementedError: If `num_class` is not ``None``.
        RuntimeError: If `num_val` is not a positive value.
        ValueError: If `num_samples` is a negative value.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # creates a PKSampler that will get 3 samples from every class.
        >>> sampler = ds.PKSampler(3)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
        if not isinstance(num_val, int):
            raise TypeError("num_val must be integer but was: {}.".format(num_val))

        if num_class is not None:
            raise NotImplementedError("Not supported to specify num_class for PKSampler.")

        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle))

        if not isinstance(class_column, str):
            raise TypeError("class_column must be a str value but was: {}.".format(class_column))

        _check_num_samples(num_samples)

        self.num_val = num_val
        self.shuffle = shuffle
        self.class_column = class_column  # work for minddataset
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        shuffle = self.shuffle if self.shuffle is not None else False
        c_sampler = cde.PKSamplerObj(self.num_val, shuffle, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return self.shuffle

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        if not self.class_column or not isinstance(self.class_column, str):
            # Adjacent literals instead of a backslash continuation inside the string, which
            # leaked the source indentation into the error message.
            raise ValueError("class_column should be a not empty string value, "
                             "but got class_column: {}.".format(self.class_column))
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        c_sampler.set_num_samples(num_samples)
        return c_sampler


class RandomSampler(BuiltinSampler):
    """
    Samples the elements randomly.

    Args:
        replacement (bool, optional): If True, put the sample ID back for the next draw. Default: ``False``.
        num_samples (int, optional): Number of elements to sample. Default: ``None`` , which means sample all elements.

    Raises:
        TypeError: If `replacement` is not of type bool.
        TypeError: If `num_samples` is not of type int.
        ValueError: If `num_samples` is a negative value.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # creates a RandomSampler
        >>> sampler = ds.RandomSampler()
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, replacement=False, num_samples=None):
        if not isinstance(replacement, bool):
            raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))

        _check_num_samples(num_samples)

        self.deterministic = False
        self.replacement = replacement
        self.reshuffle_each_epoch = True
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        replacement = self.replacement if self.replacement is not None else False
        c_sampler = cde.RandomSamplerObj(replacement, num_samples, self.reshuffle_each_epoch)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        c_sampler.set_num_samples(num_samples)
        return c_sampler

    def is_shuffled(self):
        return True

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()


class SequentialSampler(BuiltinSampler):
    """
    Samples the dataset elements sequentially that is equivalent to not using a sampler.

    Args:
        start_index (int, optional): Index to start sampling at. Default: ``None`` , start at first ID.
        num_samples (int, optional): Number of elements to sample. Default: ``None`` , which means sample all elements.

    Raises:
        TypeError: If `start_index` is not of type int.
        TypeError: If `num_samples` is not of type int.
        RuntimeError: If `start_index` is a negative value.
        ValueError: If `num_samples` is a negative value.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # creates a SequentialSampler
        >>> sampler = ds.SequentialSampler()
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, start_index=None, num_samples=None):
        if start_index is not None and not isinstance(start_index, int):
            raise TypeError("start_index must be integer but was: {}.".format(start_index))

        _check_num_samples(num_samples)

        self.start_index = start_index
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        start_index = self.start_index if self.start_index is not None else 0
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.SequentialSamplerObj(start_index, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        start_index = self.start_index if self.start_index is not None else 0
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        c_sampler.set_num_samples(num_samples)
        return c_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()


class SubsetSampler(BuiltinSampler):
    """
    Samples the elements from a sequence of indices.

    Args:
        indices (Iterable): A sequence of indices (Any iterable Python object but string).
        num_samples (int, optional): Number of elements to sample. Default: ``None`` , which means sample all elements.

    Raises:
        TypeError: If `indices` is a string or is neither iterable nor an int.
        TypeError: If elements of `indices` are not of type int.
        TypeError: If `num_samples` is not of type int.
        ValueError: If `num_samples` is a negative value.

    Examples:
        >>> import mindspore.dataset as ds
        >>> indices = [0, 1, 2, 3, 4, 5]
        >>>
        >>> # creates a SubsetSampler, will sample from the provided indices
        >>> sampler = ds.SubsetSampler(indices)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, indices, num_samples=None):
        _check_num_samples(num_samples)

        if not isinstance(indices, str) and validator.is_iterable(indices):
            # Keep at most num_samples ids. islice(..., None) takes everything and, unlike
            # zip(indices, range(n)), never pulls an extra element out of an iterator.
            indices = list(itertools.islice(indices, num_samples))
        elif isinstance(indices, int):
            indices = [indices]
        else:
            raise TypeError('Unsupported sampler object of type ({})'.format(type(indices)))

        for i, item in enumerate(indices):
            if not isinstance(item, (int, np.integer)):
                raise TypeError("SubsetSampler: Type of indices element must be int, "
                                "but got list[{}]: {}, type: {}.".format(i, item, type(item)))

        self.indices = indices
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.SubsetSamplerObj(self.indices, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        return False

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        c_sampler = cde.MindrecordSubsetSampler(self.indices)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        c_sampler.set_num_samples(self.get_num_samples())
        return c_sampler

    def get_num_samples(self):
        # Never report more samples than there are indices to draw from.
        num_samples = super().get_num_samples()
        if num_samples is None:
            return len(self.indices)

        return min(len(self.indices), num_samples)


class SubsetRandomSampler(SubsetSampler):
    """
    Samples the elements randomly from a sequence of indices.

    Args:
        indices (Iterable): A sequence of indices (Any iterable Python object but string).
        num_samples (int, optional): Number of elements to sample. Default: ``None`` , which means sample all elements.

    Raises:
        TypeError: If elements of `indices` are not of type int.
        TypeError: If `num_samples` is not of type int.
        ValueError: If `num_samples` is a negative value.

    Examples:
        >>> import mindspore.dataset as ds
        >>> indices = [0, 1, 2, 3, 7, 88, 119]
        >>>
        >>> # create a SubsetRandomSampler, will sample from the provided indices
        >>> sampler = ds.SubsetRandomSampler(indices)
        >>> data = ds.ImageFolderDataset(image_folder_dataset_dir, num_parallel_workers=8, sampler=sampler)
    """

    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.SubsetRandomSamplerObj(self.indices, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        return True

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        c_sampler = cde.MindrecordSubsetSampler(self.indices, ds.config.get_seed())
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        c_sampler.set_num_samples(self.get_num_samples())
        return c_sampler


class IterSampler(Sampler):
    """
    User provided an iterable object without inheriting from our Sampler class.

    Note:
        This class exists to allow handshake logic between dataset operations and user defined samplers.
        By constructing this object we avoid the user having to inherit from our Sampler class.

    Args:
        sampler (iterable object): an user defined iterable object.
        num_samples (int, optional): Number of elements to sample. Default: ``None`` , which means sample all elements.

    Examples:
        >>> import mindspore.dataset as ds
        >>> class MySampler:
        ...     def __iter__(self):
        ...         for i in range(99, -1, -1):
        ...             yield i

        >>> # creates an IterSampler
        >>> sampler = ds.IterSampler(sampler=MySampler())
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, sampler, num_samples=None):
        if num_samples is None:
            # Counting the samples needs one full pass over `sampler`. A one-shot iterator
            # (e.g. a generator) is exhausted by that pass and would yield nothing afterwards,
            # so keep the collected indices in its place. Re-iterable objects are kept as-is.
            collected = list(sampler)
            num_samples = len(collected)
            if isinstance(sampler, Iterator):
                sampler = collected
        super().__init__(num_samples=num_samples)
        self.sampler = sampler

    def __iter__(self):
        return iter(self.sampler)


class WeightedRandomSampler(BuiltinSampler):
    """
    Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).

    Args:
        weights (list[float, int]): A sequence of weights, not necessarily summing up to 1.
        num_samples (int, optional): Number of elements to sample. Default: ``None`` , which means sample all elements.
        replacement (bool): If ``True``, put the sample ID back for the next draw. Default: ``True``.

    Raises:
        TypeError: If elements of `weights` are not of type number.
        TypeError: If `num_samples` is not of type int.
        TypeError: If `replacement` is not of type bool.
        RuntimeError: If `weights` is empty or all zero.
        ValueError: If `num_samples` is a negative value.

    Examples:
        >>> import mindspore.dataset as ds
        >>> weights = [0.9, 0.01, 0.4, 0.8, 0.1, 0.1, 0.3]
        >>>
        >>> # creates a WeightedRandomSampler that will sample 4 elements with replacement (the default)
        >>> sampler = ds.WeightedRandomSampler(weights, 4)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, weights, num_samples=None, replacement=True):
        # Anything that is not a list is treated as a single weight.
        if not isinstance(weights, list):
            weights = [weights]

        for ind, w in enumerate(weights):
            if not isinstance(w, numbers.Number):
                raise TypeError("type of weights element must be number, "
                                "but got w[{}]: {}, type: {}.".format(ind, w, type(w)))

        _check_num_samples(num_samples)

        if not isinstance(replacement, bool):
            raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))

        self.weights = weights
        self.replacement = replacement
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        replacement = self.replacement if self.replacement is not None else True
        c_sampler = cde.WeightedRandomSamplerObj(self.weights, num_samples, replacement)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        return True

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()