# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The sampler module provides several samplers to generate data from datasets.
The provided samplers include: DistributedSampler, PKSampler, RandomSampler,
SequentialSampler, SubsetSampler, SubsetRandomSampler, and WeightedRandomSampler.
Users can also define a custom sampler by extending from the Sampler class.
"""

import numbers
import numpy as np
import mindspore._c_dataengine as cde
import mindspore.dataset as ds
from ..core import validator_helpers as validator


def _check_num_samples(num_samples):
    """
    Validate the num_samples argument shared by all builtin samplers.

    Args:
        num_samples (int, optional): Number of samples requested by the user. None is accepted
            and skips every check.

    Raises:
        TypeError: If num_samples is not None and not an integer value.
        ValueError: If num_samples is negative or greater than validator.INT64_MAX.
    """
    if num_samples is None:
        return
    if not isinstance(num_samples, int):
        raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
    if num_samples < 0 or num_samples > validator.INT64_MAX:
        raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                         .format(0, validator.INT64_MAX))


def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
    """
    Create sampler based on user input.

    Args:
        num_samples (int): Number of samples.
        input_sampler (Union[Iterable, Sampler]): Sampler from user.
        shuffle (bool): Shuffle.
        num_shards (int): Number of shard for sharding.
        shard_id (int): Shard ID.

    Returns:
        Sampler, sampler selected based on user input.

    Raises:
        ValueError: If a builtin sampler is given together with any of num_samples, num_shards,
            shard_id or shuffle.
        TypeError: If input_sampler is not None and is of an unsupported type.
    """

    if input_sampler is not None:
        # If the user provided a sampler, then it doesn't matter what the other args are because
        # we are being asked specifically to use the given sampler.
        # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
        # be None. Consider this example:
        #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
        #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
        # In this case, the user has given different sample-related arguments that contradict each other.
        # To prevent this, only allow the user to manually specify the sampler if those arguments are all None
        if isinstance(input_sampler, BuiltinSampler):
            if any(arg is not None for arg in (num_shards, shard_id, shuffle, num_samples)):
                raise ValueError(
                    'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
                    ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
            return input_sampler
        if isinstance(input_sampler, (np.ndarray, list)):
            return SubsetSampler(input_sampler, num_samples)
        if not isinstance(input_sampler, str) and validator.is_iterable(input_sampler):
            # in this case, the user passed in their own sampler object that's not of type BuiltinSampler
            return IterSampler(input_sampler, num_samples)
        if isinstance(input_sampler, int):
            return SubsetSampler([input_sampler])
        raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))

    if shuffle is None:
        if num_shards is not None:
            # If shuffle is not specified, sharding enabled, use distributed random sampler
            return DistributedSampler(num_shards, shard_id, shuffle=True, num_samples=num_samples)
        # If shuffle is not specified, sharding disabled, use random sampler
        # NOTE(review): num_samples == 0 is excluded here but not in the `shuffle is True` branch
        # below; the asymmetry is kept as-is -- confirm whether it is intentional.
        if num_samples is not None and num_samples != 0:
            return RandomSampler(replacement=True, num_samples=num_samples)
        return RandomSampler(num_samples=num_samples)

    if shuffle is True:
        if num_shards is not None:
            # If shuffle enabled, sharding enabled, use distributed random sampler
            return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
        # If shuffle enabled, sharding disabled, use random sampler
        if num_samples is not None:
            return RandomSampler(replacement=True, num_samples=num_samples)
        return RandomSampler(num_samples=num_samples)

    if num_shards is not None:
        # If shuffle disabled, sharding enabled, use distributed sequential sampler
        return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
    # If shuffle disabled, sharding disabled, use sequential sampler
    return SequentialSampler(num_samples=num_samples)


class BuiltinSampler:
    """
    Base class for BuiltinSampler.

    User should not extend this class.
    """

    def __init__(self, num_samples=None):
        self.child_sampler = None
        self.num_samples = num_samples

    def parse(self):
        """ Parse the sampler."""

    def add_child(self, sampler):
        """
        Add a sub-sampler for given sampler. The sub-sampler will receive all data from the
        output of parent sampler and apply its sample logic to return new samples.

        Args:
            sampler (Sampler): Object used to choose samples from the dataset. Only builtin
                samplers(DistributedSampler, PKSampler, RandomSampler, SequentialSampler,
                SubsetRandomSampler, WeightedRandomSampler) are supported.

        Examples:
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> sampler.add_child(ds.RandomSampler(num_samples=2))
            >>> dataset = ds.Cifar10Dataset(cifar10_dataset_dir, sampler=sampler)
        """
        self.child_sampler = sampler

    def get_child(self):
        """
        Get the child sampler of given sampler.

        Returns:
            Sampler, The child sampler of given sampler.

        Examples:
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> sampler.add_child(ds.RandomSampler(num_samples=2))
            >>> child_sampler = sampler.get_child()
        """
        return self.child_sampler

    def parse_child(self):
        """ Parse the child sampler. """
        if self.child_sampler is None:
            return None
        return self.child_sampler.parse()

    def parse_child_for_minddataset(self):
        """ Parse the child sampler for MindRecord. """
        if self.child_sampler is None:
            return None
        return self.child_sampler.parse_for_minddataset()

    def is_shuffled(self):
        """ Not implemented. """
        raise NotImplementedError("Sampler must implement is_shuffled.")

    def is_sharded(self):
        """ Not implemented. """
        raise NotImplementedError("Sampler must implement is_sharded.")

    def get_num_samples(self):
        """
        All samplers can contain a numeric num_samples value (or it can be set to None).
        A child sampler can exist or be None.
        If a child sampler exists, then the child sampler count can be a numeric value or None.
        These conditions impact the resultant sampler count that is used.
        The following table shows the possible results from calling this function.

        .. list-table::
           :widths: 25 25 25 25
           :header-rows: 1

           * - child sampler
             - num_samples
             - child_samples
             - result
           * - T
             - x
             - y
             - min(x, y)
           * - T
             - x
             - None
             - x
           * - T
             - None
             - y
             - y
           * - T
             - None
             - None
             - None
           * - None
             - x
             - n/a
             - x
           * - None
             - None
             - n/a
             - None

        Returns:
            int, the number of samples, or None.

        Examples:
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> num_samples = sampler.get_num_samples()
        """
        if self.child_sampler is None:
            return self.num_samples

        child_samples = self.child_sampler.get_num_samples()
        if self.num_samples is None:
            return child_samples
        if child_samples is None:
            return self.num_samples
        return min(self.num_samples, child_samples)


class Sampler(BuiltinSampler):
    """
    Base class for user defined sampler.
    A user defined sampler can be used with any existing dataset with sampler support.

    A required __iter__() method should be overridden by the user for sample index generation.
    An optional reset() method can be overridden for per repeat reset.

    dataset_size and num_samples will be set by dataset once a dataset iterator is created.

    Examples:
        >>> class ReverseSampler(ds.Sampler):
        ...     def __iter__(self):
        ...         for i in range(self.dataset_size - 1, -1, -1):
        ...             yield i
        >>>
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, sampler=ReverseSampler())
    """

    def __init__(self, num_samples=None):
        # child_sampler and num_samples are initialised by the base class.
        super().__init__(num_samples)
        self.dataset_size = 0

    def __iter__(self):
        """
        User defined iterator, must be overridden.
        _handshake is guaranteed to be called prior to iterator construction.
        """
        raise NotImplementedError

    def reset(self):
        """
        Per repeat reset callback, override this method if necessary
        """

    # Initialization handshake callback
    # Do not override this method!
    def _handshake(self, ds_size, num_samples):
        self.dataset_size = ds_size
        self.num_samples = num_samples

    # Indices fetcher
    # Do not override this method!
    # pylint: disable=missing-docstring
    def _get_indices(self):
        # Pull at most num_samples indices; a shorter iterator simply ends the loop early.
        sampler_iter = iter(self)
        ret = []
        for _ in range(self.num_samples):
            try:
                ret.append(next(sampler_iter))
            except StopIteration:
                break
        indices = np.array(ret)
        # An object dtype means the fetched values could not be packed into a numeric array.
        if indices.dtype == object:
            raise RuntimeError("Fetched indices can not be converted to a valid ndarray.")
        return indices

    # Instance fetcher
    # Do not override this method!
    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.PreBuiltSamplerObj(num_samples, self)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()

    def get_num_samples(self):
        if self.num_samples is None:
            return None
        return self._get_indices().size


class DistributedSampler(BuiltinSampler):
    """
    A sampler that accesses a shard of the dataset, it helps divide dataset into multi-subset for distributed training.

    Args:
        num_shards (int): Number of shards to divide the dataset into.
        shard_id (int): Shard ID of the current shard, which should be within the range of [0, num_shards-1].
        shuffle (bool, optional): If True, the indices are shuffled, otherwise it will not be shuffled(default=True).
        num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements).
        offset(int, optional): The starting shard ID where the elements in the dataset are sent to, which
            should be no more than num_shards. This parameter is only valid when a ConcatDataset takes
            a DistributedSampler as its sampler. It will affect the number of samples of per shard
            (default=-1, which means each shard has same number of samples).

    Examples:
        >>> # creates a distributed sampler with 10 shards in total. This shard is shard 5.
        >>> sampler = ds.DistributedSampler(10, 5)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If num_shards is not an integer value.
        TypeError: If shard_id is not an integer value.
        TypeError: If shuffle is not a boolean value.
        TypeError: If num_samples is not an integer value.
        TypeError: If offset is not an integer value.
        ValueError: If num_samples is a negative value.
        RuntimeError: If num_shards is not a positive value.
        RuntimeError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
        RuntimeError: If offset is greater than num_shards.
    """

    def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
        if not isinstance(num_shards, int):
            raise TypeError("num_shards must be integer but was: {}.".format(num_shards))

        if not isinstance(shard_id, int):
            raise TypeError("shard_id must be integer but was: {}.".format(shard_id))

        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle))

        _check_num_samples(num_samples)

        if not isinstance(offset, int):
            raise TypeError("offset must be integer but was: {}.".format(offset))

        self.num_shards = num_shards
        self.shard_id = shard_id
        self.shuffle = shuffle
        self.seed = 0
        self.offset = offset
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        shuffle = self.shuffle if self.shuffle is not None else True
        offset = self.offset if self.offset is not None else -1
        # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
        self.seed += 1
        c_sampler = cde.DistributedSamplerObj(self.num_shards, self.shard_id,
                                              shuffle, num_samples, self.seed, offset, True)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def parse_for_minddataset(self):
        """ Parse the sampler for MindRecord."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        shuffle = self.shuffle if self.shuffle is not None else True
        # Same fallback as parse(): set_offset() does not validate, so offset may have been set to None.
        offset = self.offset if self.offset is not None else -1
        c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, shuffle,
                                                     self.seed, num_samples, offset)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return self.shuffle

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return self.num_shards > 1

        return self.child_sampler.is_sharded()

    def set_offset(self, offset):
        self.offset = offset
        return self


class PKSampler(BuiltinSampler):
    """
    Samples K elements for each P class in the dataset.

    Args:
        num_val (int): Number of elements to sample for each class.
        num_class (int, optional): Number of classes to sample (default=None, sample all classes).
            Specifying this parameter is not supported currently.
        shuffle (bool, optional): If True, the class IDs are shuffled, otherwise it will not be
            shuffled (default=False).
        class_column (str, optional): Name of column with class labels for MindDataset (default='label').
        num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements).

    Examples:
        >>> # creates a PKSampler that will get 3 samples from every class.
        >>> sampler = ds.PKSampler(3)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If num_val is not an integer value.
        TypeError: If shuffle is not a boolean value.
        TypeError: If class_column is not a str value.
        TypeError: If num_samples is not an integer value.
        NotImplementedError: If num_class is not None.
        RuntimeError: If num_val is not a positive value.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
        if not isinstance(num_val, int):
            raise TypeError("num_val must be integer but was: {}.".format(num_val))

        if num_class is not None:
            raise NotImplementedError("Not supported to specify num_class for PKSampler.")

        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle))

        if not isinstance(class_column, str):
            raise TypeError("class_column must be a str value but was: {}.".format(class_column))

        _check_num_samples(num_samples)

        self.num_val = num_val
        self.shuffle = shuffle
        self.class_column = class_column  # work for minddataset
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        shuffle = self.shuffle if self.shuffle is not None else False
        c_sampler = cde.PKSamplerObj(self.num_val, shuffle, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return self.shuffle

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        if not self.class_column or not isinstance(self.class_column, str):
            # Adjacent literals instead of a backslash continuation, which would embed the
            # source indentation in the message.
            raise ValueError("class_column should be a not empty string value, "
                             "but got class_column: {}.".format(self.class_column))
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler


class RandomSampler(BuiltinSampler):
    """
    Samples the elements randomly.

    Args:
        replacement (bool, optional): If True, put the sample ID back for the next draw (default=False).
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> # creates a RandomSampler
        >>> sampler = ds.RandomSampler()
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If replacement is not a boolean value.
        TypeError: If num_samples is not an integer value.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, replacement=False, num_samples=None):
        if not isinstance(replacement, bool):
            raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))

        _check_num_samples(num_samples)

        self.deterministic = False
        self.replacement = replacement
        self.reshuffle_each_epoch = True
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        replacement = self.replacement if self.replacement is not None else False
        c_sampler = cde.RandomSamplerObj(replacement, num_samples, self.reshuffle_each_epoch)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        return True

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()


class SequentialSampler(BuiltinSampler):
    """
    Samples the dataset elements sequentially that is equivalent to not using a sampler.

    Args:
        start_index (int, optional): Index to start sampling at. (default=None, start at first ID)
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> # creates a SequentialSampler
        >>> sampler = ds.SequentialSampler()
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If start_index is not an integer value.
        TypeError: If num_samples is not an integer value.
        RuntimeError: If start_index is a negative value.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, start_index=None, num_samples=None):
        if start_index is not None and not isinstance(start_index, int):
            raise TypeError("start_index must be integer but was: {}.".format(start_index))

        _check_num_samples(num_samples)

        self.start_index = start_index
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        start_index = self.start_index if self.start_index is not None else 0
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.SequentialSamplerObj(start_index, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        start_index = self.start_index if self.start_index is not None else 0
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()


class SubsetSampler(BuiltinSampler):
    """
    Samples the elements from a sequence of indices.

    Args:
        indices (Any iterable Python object but string): A sequence of indices.
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> indices = [0, 1, 2, 3, 4, 5]
        >>>
        >>> # creates a SubsetSampler, will sample from the provided indices
        >>> sampler = ds.SubsetSampler(indices)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If type of indices element is not a number.
        TypeError: If num_samples is not an integer value.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, indices, num_samples=None):
        def _get_sample_ids_as_list(sampler, number_of_samples=None):
            # Materialize the indices as a list, keeping at most number_of_samples of them.
            if number_of_samples is None:
                return list(sampler)

            if isinstance(sampler, list):
                return sampler[:number_of_samples]

            return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]

        _check_num_samples(num_samples)

        if not isinstance(indices, str) and validator.is_iterable(indices):
            indices = _get_sample_ids_as_list(indices, num_samples)
        elif isinstance(indices, int):
            indices = [indices]
        else:
            raise TypeError('Unsupported sampler object of type ({})'.format(type(indices)))

        for i, item in enumerate(indices):
            if not isinstance(item, (int, np.integer)):
                raise TypeError("SubsetSampler: Type of indices element must be int, "
                                "but got list[{}]: {}, type: {}.".format(i, item, type(item)))

        self.indices = indices
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.SubsetSamplerObj(self.indices, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        return False

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        c_sampler = cde.MindrecordSubsetSampler(self.indices)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def get_num_samples(self):
        num_samples = super().get_num_samples()
        if num_samples is None:
            return len(self.indices)

        return min(len(self.indices), num_samples)


class SubsetRandomSampler(SubsetSampler):
    """
    Samples the elements randomly from a sequence of indices.

    Args:
        indices (Any iterable Python object but string): A sequence of indices.
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> indices = [0, 1, 2, 3, 7, 88, 119]
        >>>
        >>> # create a SubsetRandomSampler, will sample from the provided indices
        >>> sampler = ds.SubsetRandomSampler(indices)
        >>> data = ds.ImageFolderDataset(image_folder_dataset_dir, num_parallel_workers=8, sampler=sampler)

    Raises:
        TypeError: If type of indices element is not a number.
        TypeError: If num_samples is not an integer value.
        ValueError: If num_samples is a negative value.
    """

    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.SubsetRandomSamplerObj(self.indices, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        return True

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        c_sampler = cde.MindrecordSubsetSampler(self.indices, ds.config.get_seed())
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler


class IterSampler(Sampler):
    """
    User provided an iterable object without inheriting from our Sampler class.

    Note:
        This class exists to allow handshake logic between dataset operators and user defined samplers.
        By constructing this object we avoid the user having to inherit from our Sampler class.

    Args:
        sampler (iterable object): an user defined iterable object.
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> class MySampler:
        ...     def __iter__(self):
        ...         for i in range(99, -1, -1):
        ...             yield i

        >>> # creates an IterSampler
        >>> sampler = ds.IterSampler(sampler=MySampler())
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, sampler, num_samples=None):
        if num_samples is None:
            # Counting the elements requires a full pass over the iterable.
            indices = list(sampler)
            num_samples = len(indices)
            # A one-shot iterator (e.g. a generator object) has just been exhausted by that pass;
            # keep the materialized indices so that __iter__ can still yield them.
            if iter(sampler) is sampler:
                sampler = indices
        super().__init__(num_samples=num_samples)
        self.sampler = sampler

    def __iter__(self):
        return iter(self.sampler)


class WeightedRandomSampler(BuiltinSampler):
    """
    Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).

    Args:
        weights (list[float, int]): A sequence of weights, not necessarily summing up to 1.
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).
        replacement (bool): If True, put the sample ID back for the next draw (default=True).

    Examples:
        >>> weights = [0.9, 0.01, 0.4, 0.8, 0.1, 0.1, 0.3]
        >>>
        >>> # creates a WeightedRandomSampler that will sample 4 elements with replacement
        >>> sampler = ds.WeightedRandomSampler(weights, 4)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If type of weights element is not a number.
        TypeError: If num_samples is not an integer value.
        TypeError: If replacement is not a boolean value.
        RuntimeError: If `weights` is empty or all zero.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, weights, num_samples=None, replacement=True):
        # Anything that is not a list is treated as a single weight.
        if not isinstance(weights, list):
            weights = [weights]

        for ind, w in enumerate(weights):
            if not isinstance(w, numbers.Number):
                raise TypeError("type of weights element must be number, "
                                "but got w[{}]: {}, type: {}.".format(ind, w, type(w)))

        _check_num_samples(num_samples)

        if not isinstance(replacement, bool):
            raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))

        self.weights = weights
        self.replacement = replacement
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        replacement = self.replacement if self.replacement is not None else True
        c_sampler = cde.WeightedRandomSamplerObj(self.weights, num_samples, replacement)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        return True

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()