• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16The sampler module provides several samplers to generate data from datasets.
17The provided samplers include: DistributedSampler, PKSampler, RandomSampler,
18SequentialSampler, SubsetRandomSampler, and WeightedRandomSampler.
19Users can also define a custom sampler by extending from the Sampler class.
20"""
21
22import numbers
23import numpy as np
24import mindspore._c_dataengine as cde
25import mindspore.dataset as ds
26from ..core import validator_helpers as validator
27
28
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
    """
    Choose the sampler that corresponds to the sampling arguments given by the user.

    Args:
        num_samples (int): Number of samples, or None if not specified.
        input_sampler (Union[Iterable, Sampler]): Sampler from user, or None if not specified.
        shuffle (bool): Shuffle flag, or None if not specified.
        num_shards (int): Number of shards for sharding, or None if sharding is disabled.
        shard_id (int): Shard ID.

    Returns:
        Sampler, sampler selected based on user input.

    Raises:
        ValueError: If input_sampler is a builtin sampler and any of num_samples, num_shards,
            shard_id or shuffle is not None.
        TypeError: If input_sampler is of a type that cannot be turned into a sampler.
    """
    if input_sampler is not None:
        if isinstance(input_sampler, BuiltinSampler):
            # A builtin sampler already carries its own sharding/shuffle/count settings, so a call like
            #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
            #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
            # would contradict itself. The remaining sampling arguments must therefore all be None.
            other_args = (num_shards, shard_id, shuffle, num_samples)
            if any(arg is not None for arg in other_args):
                raise ValueError(
                    'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
                    ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
            return input_sampler

        if not isinstance(input_sampler, str):
            if isinstance(input_sampler, (np.ndarray, list)):
                return SubsetSampler(input_sampler, num_samples)
            if validator.is_iterable(input_sampler):
                # a user supplied sampler object that does not derive from BuiltinSampler
                return IterSampler(input_sampler, num_samples)

        if isinstance(input_sampler, int):
            return SubsetSampler([input_sampler])

        raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))

    sharding_enabled = num_shards is not None

    if shuffle is None:
        # shuffle not specified: behave as a shuffling sampler
        if sharding_enabled:
            return DistributedSampler(num_shards, shard_id, shuffle=True, num_samples=num_samples)
        if num_samples is not None and num_samples != 0:
            return RandomSampler(replacement=True, num_samples=num_samples)
        return RandomSampler(num_samples=num_samples)

    if shuffle is True:
        # shuffle explicitly enabled
        if sharding_enabled:
            return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
        if num_samples is not None:
            return RandomSampler(replacement=True, num_samples=num_samples)
        return RandomSampler(num_samples=num_samples)

    # shuffle disabled: distributed sequential sampler when sharded, plain sequential otherwise
    if sharding_enabled:
        return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
    return SequentialSampler(num_samples=num_samples)
90
91
class BuiltinSampler:
    """
    Common base of the builtin samplers.

    It stores the optional sample count and an optional child sampler.
    User should not extend this class.
    """

    def __init__(self, num_samples=None):
        self.num_samples = num_samples
        self.child_sampler = None

    def parse(self):
        """ Parse the sampler."""

    def add_child(self, sampler):
        """
        Add a sub-sampler for given sampler. The sub-sampler will receive all data from the
        output of parent sampler and apply its sample logic to return new samples.

        Args:
            sampler (Sampler): Object used to choose samples from the dataset. Only builtin
                samplers(DistributedSampler, PKSampler, RandomSampler, SequentialSampler,
                SubsetRandomSampler, WeightedRandomSampler) are supported.

        Examples:
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> sampler.add_child(ds.RandomSampler(num_samples=2))
            >>> dataset = ds.Cifar10Dataset(cifar10_dataset_dir, sampler=sampler)
        """
        self.child_sampler = sampler

    def get_child(self):
        """
        Get the child sampler of given sampler.

        Returns:
            Sampler, the child sampler of given sampler, or None if no child was added.

        Examples:
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> sampler.add_child(ds.RandomSampler(num_samples=2))
            >>> child_sampler = sampler.get_child()
        """
        return self.child_sampler

    def parse_child(self):
        """ Parse the child sampler; returns None when there is no child. """
        child = self.child_sampler
        return None if child is None else child.parse()

    def parse_child_for_minddataset(self):
        """ Parse the child sampler for MindRecord; returns None when there is no child. """
        child = self.child_sampler
        return None if child is None else child.parse_for_minddataset()

    def is_shuffled(self):
        """ Not implemented. """
        raise NotImplementedError("Sampler must implement is_shuffled.")

    def is_sharded(self):
        """ Not implemented. """
        raise NotImplementedError("Sampler must implement is_sharded.")

    def get_num_samples(self):
        """
        Combine this sampler's num_samples with the count reported by its child sampler.

        Both this sampler's num_samples and the child's count may be numeric or None, and the
        child itself may be absent. The result for every combination is listed below.

        .. list-table::
           :widths: 25 25 25 25
           :header-rows: 1

           * - child sampler
             - num_samples
             - child_samples
             - result
           * - T
             - x
             - y
             - min(x, y)
           * - T
             - x
             - None
             - x
           * - T
             - None
             - y
             - y
           * - T
             - None
             - None
             - None
           * - None
             - x
             - n/a
             - x
           * - None
             - None
             - n/a
             - None

        Returns:
            int, the number of samples, or None.

        Examples:
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> num_samplers = sampler.get_num_samples()
        """
        own_count = self.num_samples
        if self.child_sampler is None:
            return own_count

        child_count = self.child_sampler.get_num_samples()
        if own_count is None:
            return child_count
        if child_count is None:
            return own_count
        return min(own_count, child_count)
218
219
class Sampler(BuiltinSampler):
    """
    Base class for user defined sampler.
    A user defined sampler can be used with any existing dataset with sampler support.

    A required __iter__() method should be overridden by the user for sample index generation.
    An optional reset() method can be overridden for per repeat reset.

    dataset_size and num_samples will be set by dataset once a dataset iterator is created.

    Examples:
        >>> class ReverseSampler(ds.Sampler):
        ...     def __iter__(self):
        ...         for i in range(self.dataset_size - 1, -1, -1):
        ...             yield i
        >>>
        >>> ds = ds.ImageFolderDataset(image_folder_dataset_dir, sampler=ReverseSampler())
    """

    def __init__(self, num_samples=None):
        super().__init__(num_samples)
        self.num_samples = num_samples
        self.child_sampler = None
        self.dataset_size = 0

    def __iter__(self):
        """
        User defined iterator, must be overridden.
        _handshake is guaranteed to be called prior to iterator construction.
        """
        raise NotImplementedError

    def reset(self):
        """
        Per repeat reset callback, override this method if necessary
        """

    # Initialization handshake callback
    # Do not override this method!
    def _handshake(self, ds_size, num_samples):
        self.num_samples = num_samples
        self.dataset_size = ds_size

    # Indices fetcher
    # Do not override this method!
    # pylint: disable=missing-docstring
    def _get_indices(self):
        index_source = iter(self)
        # range() leads the zip so no more than num_samples indices are pulled from the iterator;
        # the zip also ends early when the iterator runs out first.
        fetched = [index for _, index in zip(range(self.num_samples), index_source)]
        as_array = np.array(fetched)
        if as_array.dtype == object:
            raise RuntimeError("Fetched indices can not be converted to a valid ndarray.")
        return as_array

    # Instance fetcher
    # Do not override this method!
    def parse(self):
        """ Parse the sampler."""
        # a missing num_samples is handed over as 0
        sample_count = 0 if self.num_samples is None else self.num_samples
        c_sampler = cde.PreBuiltSamplerObj(sample_count, self)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def add_child(self, sampler):
        self.child_sampler = sampler

    def get_child(self):
        return self.child_sampler

    def parse_child(self):
        child = self.child_sampler
        return None if child is None else child.parse()

    def is_shuffled(self):
        # without a child there is nothing that shuffles; otherwise the child decides
        child = self.child_sampler
        return False if child is None else child.is_shuffled()

    def is_sharded(self):
        # without a child there is nothing that shards; otherwise the child decides
        child = self.child_sampler
        return False if child is None else child.is_sharded()

    def get_num_samples(self):
        # the count is the number of indices actually produced, capped by num_samples
        return None if self.num_samples is None else self._get_indices().size
319
320
class DistributedSampler(BuiltinSampler):
    """
    A sampler that accesses a shard of the dataset, it helps divide dataset into multi-subset for distributed training.

    Args:
        num_shards (int): Number of shards to divide the dataset into.
        shard_id (int): Shard ID of the current shard, which should within the range of [0, num_shards-1].
        shuffle (bool, optional): If True, the indices are shuffled, otherwise it will not be shuffled(default=True).
        num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements).
        offset(int, optional): The starting shard ID where the elements in the dataset are sent to, which
            should be no more than num_shards. This parameter is only valid when a ConcatDataset takes
            a DistributedSampler as its sampler. It will affect the number of samples of per shard
            (default=-1, which means each shard has same number of samples).

    Examples:
        >>> # creates a distributed sampler with 10 shards in total. This shard is shard 5.
        >>> sampler = ds.DistributedSampler(10, 5)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If num_shards is not an integer value.
        TypeError: If shard_id is not an integer value.
        TypeError: If shuffle is not a boolean value.
        TypeError: If num_samples is not an integer value.
        TypeError: If offset is not an integer value.
        ValueError: If num_samples is a negative value.
        RuntimeError: If num_shards is not a positive value.
        RuntimeError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
        RuntimeError: If offset is greater than num_shards.
    """

    def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
        # Argument checks run in a fixed order so the first offending argument is the one reported.
        if not isinstance(num_shards, int):
            raise TypeError("num_shards must be integer but was: {}.".format(num_shards))
        if not isinstance(shard_id, int):
            raise TypeError("shard_id must be integer but was: {}.".format(shard_id))
        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle))
        if num_samples is not None and not isinstance(num_samples, int):
            raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
        if num_samples is not None and not 0 <= num_samples <= validator.INT64_MAX:
            raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                             .format(0, validator.INT64_MAX))
        if not isinstance(offset, int):
            raise TypeError("offset must be integer but was: {}.".format(offset))

        super().__init__(num_samples)
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.shuffle = shuffle
        self.offset = offset
        self.seed = 0

    def parse(self):
        """ Parse the sampler."""
        # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
        self.seed += 1
        # unset values fall back to: shuffle=True, num_samples=0, offset=-1
        shuffle_flag = True if self.shuffle is None else self.shuffle
        sample_count = 0 if self.num_samples is None else self.num_samples
        shard_offset = -1 if self.offset is None else self.offset
        c_sampler = cde.DistributedSamplerObj(self.num_shards, self.shard_id, shuffle_flag,
                                              sample_count, self.seed, shard_offset, True)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def parse_for_minddataset(self):
        """ Parse the sampler for MindRecord."""
        shuffle_flag = True if self.shuffle is None else self.shuffle
        sample_count = 0 if self.num_samples is None else self.num_samples
        c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, shuffle_flag,
                                                     self.seed, sample_count, self.offset)
        c_sampler.add_child(self.parse_child_for_minddataset())
        return c_sampler

    def is_shuffled(self):
        # a child sampler, when present, decides; otherwise this sampler's own shuffle flag
        child = self.child_sampler
        return self.shuffle if child is None else child.is_shuffled()

    def is_sharded(self):
        # a child sampler, when present, decides; otherwise sharded means more than one shard
        child = self.child_sampler
        return self.num_shards > 1 if child is None else child.is_sharded()

    def set_offset(self, offset):
        # stores the new offset and returns this sampler itself
        self.offset = offset
        return self
419
420
class PKSampler(BuiltinSampler):
    """
    Samples K elements for each P class in the dataset.

    Args:
        num_val (int): Number of elements to sample for each class.
        num_class (int, optional): Number of classes to sample (default=None, sample all classes).
            Specifying this parameter is not supported currently.
        shuffle (bool, optional): If True, the class IDs are shuffled, otherwise it will not be
            shuffled (default=False).
        class_column (str, optional): Name of column with class labels for MindDataset (default='label').
        num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements).

    Examples:
        >>> # creates a PKSampler that will get 3 samples from every class.
        >>> sampler = ds.PKSampler(3)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If num_val is not an integer value.
        TypeError: If shuffle is not a boolean value.
        TypeError: If class_column is not a str value.
        TypeError: If num_samples is not an integer value.
        NotImplementedError: If num_class is not None.
        RuntimeError: If num_val is not a positive value.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
        if not isinstance(num_val, int):
            raise TypeError("num_val must be integer but was: {}.".format(num_val))

        if num_class is not None:
            raise NotImplementedError("Not supported to specify num_class for PKSampler.")

        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle))

        if not isinstance(class_column, str):
            raise TypeError("class_column must be a str value but was: {}.".format(class_column))

        if num_samples is not None:
            if not isinstance(num_samples, int):
                raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
            if num_samples < 0 or num_samples > validator.INT64_MAX:
                raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                                 .format(0, validator.INT64_MAX))

        self.num_val = num_val
        self.shuffle = shuffle
        self.class_column = class_column  # work for minddataset
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        # unset values fall back to: num_samples=0, shuffle=False
        num_samples = self.num_samples if self.num_samples is not None else 0
        shuffle = self.shuffle if self.shuffle is not None else False
        c_sampler = cde.PKSamplerObj(self.num_val, shuffle, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        """ Return the child sampler's answer if a child exists, else this sampler's shuffle flag. """
        if self.child_sampler is None:
            return self.shuffle

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        """ Return the child sampler's answer if a child exists, else False. """
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()

    def parse_for_minddataset(self):
        """
        Parse the sampler for MindRecord.

        Raises:
            ValueError: If class_column is empty or not a string.
        """
        if not self.class_column or not isinstance(self.class_column, str):
            # The message is assembled from adjacent literals. A backslash continuation inside
            # the literal would embed the next line's indentation in the text.
            raise ValueError("class_column should be a not empty string value, "
                             "but got class_column: {}.".format(self.class_column))
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler
506
507
class RandomSampler(BuiltinSampler):
    """
    Samples the elements randomly.

    Args:
        replacement (bool, optional): If True, put the sample ID back for the next draw (default=False).
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> # creates a RandomSampler
        >>> sampler = ds.RandomSampler()
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If replacement is not a boolean value.
        TypeError: If num_samples is not an integer value.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, replacement=False, num_samples=None):
        if not isinstance(replacement, bool):
            raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))
        if num_samples is not None and not isinstance(num_samples, int):
            raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
        if num_samples is not None and not 0 <= num_samples <= validator.INT64_MAX:
            raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                             .format(0, validator.INT64_MAX))

        super().__init__(num_samples)
        self.replacement = replacement
        self.reshuffle_each_epoch = True
        self.deterministic = False

    def parse(self):
        """ Parse the sampler."""
        # unset values fall back to: replacement=False, num_samples=0
        with_replacement = False if self.replacement is None else self.replacement
        sample_count = 0 if self.num_samples is None else self.num_samples
        c_sampler = cde.RandomSamplerObj(with_replacement, sample_count, self.reshuffle_each_epoch)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        sample_count = 0 if self.num_samples is None else self.num_samples
        c_sampler = cde.MindrecordRandomSampler(sample_count, self.replacement, self.reshuffle_each_epoch)
        c_sampler.add_child(self.parse_child_for_minddataset())
        return c_sampler

    def is_shuffled(self):
        return True

    def is_sharded(self):
        # only a child sampler can introduce sharding
        child = self.child_sampler
        return False if child is None else child.is_sharded()
570
571
class SequentialSampler(BuiltinSampler):
    """
    Samples the dataset elements sequentially that is equivalent to not using a sampler.

    Args:
        start_index (int, optional): Index to start sampling at. (default=None, start at first ID)
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> # creates a SequentialSampler
        >>> sampler = ds.SequentialSampler()
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If start_index is not an integer value.
        TypeError: If num_samples is not an integer value.
        RuntimeError: If start_index is a negative value.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, start_index=None, num_samples=None):
        if not (start_index is None or isinstance(start_index, int)):
            raise TypeError("start_index must be integer but was: {}.".format(start_index))
        if num_samples is not None and not isinstance(num_samples, int):
            raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
        if num_samples is not None and not 0 <= num_samples <= validator.INT64_MAX:
            raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                             .format(0, validator.INT64_MAX))

        super().__init__(num_samples)
        self.start_index = start_index

    def parse(self):
        """ Parse the sampler."""
        # unset values fall back to: start_index=0, num_samples=0
        first_index = 0 if self.start_index is None else self.start_index
        sample_count = 0 if self.num_samples is None else self.num_samples
        c_sampler = cde.SequentialSamplerObj(first_index, sample_count)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        first_index = 0 if self.start_index is None else self.start_index
        sample_count = 0 if self.num_samples is None else self.num_samples
        # note the argument order: sample count first, then start index
        c_sampler = cde.MindrecordSequentialSampler(sample_count, first_index)
        c_sampler.add_child(self.parse_child_for_minddataset())
        return c_sampler

    def is_shuffled(self):
        # only a child sampler can introduce shuffling
        child = self.child_sampler
        return False if child is None else child.is_shuffled()

    def is_sharded(self):
        # only a child sampler can introduce sharding
        child = self.child_sampler
        return False if child is None else child.is_sharded()
637
638
class SubsetSampler(BuiltinSampler):
    """
    Samples the elements from a sequence of indices.

    Args:
        indices (Any iterable Python object but string): A sequence of indices. A single int is
            also accepted and treated as a one-element sequence.
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> indices = [0, 1, 2, 3, 4, 5]
        >>>
        >>> # creates a SubsetSampler, will sample from the provided indices
        >>> sampler = ds.SubsetSampler(indices)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If indices is neither a non-string iterable nor an int.
        TypeError: If type of indices element is not an integer.
        TypeError: If num_samples is not an integer value.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, indices, num_samples=None):
        def _get_sample_ids_as_list(sampler, number_of_samples=None):
            # Materialize at most number_of_samples ids from sampler as a new list.
            if number_of_samples is None:
                return list(sampler)

            if isinstance(sampler, list):
                return sampler[:number_of_samples]

            # range() must lead the zip: zip advances its arguments left to right, so with the
            # sampler first one extra id would be pulled from a one-shot iterator and discarded.
            return [sample_id for _, sample_id in zip(range(number_of_samples), sampler)]

        if num_samples is not None:
            if not isinstance(num_samples, int):
                raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
            if num_samples < 0 or num_samples > validator.INT64_MAX:
                raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                                 .format(0, validator.INT64_MAX))

        if not isinstance(indices, str) and validator.is_iterable(indices):
            indices = _get_sample_ids_as_list(indices, num_samples)
        elif isinstance(indices, int):
            indices = [indices]
        else:
            raise TypeError('Unsupported sampler object of type ({})'.format(type(indices)))

        # every id must be a Python int or a numpy integer
        for i, item in enumerate(indices):
            if not isinstance(item, (int, np.integer)):
                raise TypeError("SubsetSampler: Type of indices element must be int, "
                                "but got list[{}]: {}, type: {}.".format(i, item, type(item)))

        self.indices = indices
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.SubsetSamplerObj(self.indices, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        """ Return False: this sampler does not report itself as shuffled. """
        return False

    def is_sharded(self):
        """ Return the child sampler's answer if a child exists, else False. """
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        c_sampler = cde.MindrecordSubsetSampler(self.indices)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def get_num_samples(self):
        """
        Return the sample count, never more than the number of stored indices.

        Returns:
            int, len(indices) when the base class reports None, otherwise the smaller of
            len(indices) and the base class result.
        """
        num_samples = super().get_num_samples()
        if num_samples is None:
            return len(self.indices)

        return min(len(self.indices), num_samples)
724
725
class SubsetRandomSampler(SubsetSampler):
    """
    Samples the elements randomly from a sequence of indices.

    Args:
        indices (Any iterable Python object but string): A sequence of indices.
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> indices = [0, 1, 2, 3, 7, 88, 119]
        >>>
        >>> # create a SubsetRandomSampler, will sample from the provided indices
        >>> sampler = ds.SubsetRandomSampler(indices)
        >>> data = ds.ImageFolderDataset(image_folder_dataset_dir, num_parallel_workers=8, sampler=sampler)

    Raises:
        TypeError: If type of indices element is not a number.
        TypeError: If num_samples is not an integer value.
        ValueError: If num_samples is a negative value.
    """

    def parse(self):
        """ Parse the sampler."""
        # a missing num_samples is handed over as 0
        sample_count = 0 if self.num_samples is None else self.num_samples
        c_sampler = cde.SubsetRandomSamplerObj(self.indices, sample_count)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def is_shuffled(self):
        return True

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        # the seed comes from the dataset configuration
        seed = ds.config.get_seed()
        c_sampler = cde.MindrecordSubsetSampler(self.indices, seed)
        c_sampler.add_child(self.parse_child_for_minddataset())
        return c_sampler
764
765
class IterSampler(Sampler):
    """
    Wraps a user-provided iterable object that does not inherit from our Sampler class.

    Note:
        This class exists to allow handshake logic between dataset operators and user defined samplers.
        By constructing this object we avoid the user having to inherit from our Sampler class.
        When `num_samples` is None, the iterable is walked once in the constructor to count its
        elements, so a one-shot iterator (such as a generator) is exhausted by that counting pass.

    Args:
        sampler (iterable object): a user-defined iterable object.
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> class MySampler:
        ...     def __iter__(self):
        ...         for i in range(99, -1, -1):
        ...             yield i

        >>> # creates an IterSampler
        >>> sampler = ds.IterSampler(sampler=MySampler())
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, sampler, num_samples=None):
        # Without an explicit limit, the element count of the user's iterable becomes the limit.
        sample_count = len(list(sampler)) if num_samples is None else num_samples
        super().__init__(num_samples=sample_count)
        self.sampler = sampler

    def __iter__(self):
        """Iterate by delegating to the wrapped user object."""
        return iter(self.sampler)
799
800
class WeightedRandomSampler(BuiltinSampler):
    """
    Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).

    Args:
        weights (list[float, int]): A sequence of weights, not necessarily summing up to 1.
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).
        replacement (bool): If True, put the sample ID back for the next draw (default=True).

    Examples:
        >>> weights = [0.9, 0.01, 0.4, 0.8, 0.1, 0.1, 0.3]
        >>>
        >>> # creates a WeightedRandomSampler that will sample 4 elements with replacement (the default)
        >>> sampler = ds.WeightedRandomSampler(weights, 4)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If type of weights element is not a number.
        TypeError: If num_samples is not an integer value.
        TypeError: If replacement is not a boolean value.
        RuntimeError: If `weights` is empty or all zero.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, weights, num_samples=None, replacement=True):
        # Anything that is not already a list is treated as a single weight.
        weight_list = weights if isinstance(weights, list) else [weights]

        for position, weight in enumerate(weight_list):
            if isinstance(weight, numbers.Number):
                continue
            raise TypeError("type of weights element must be number, "
                            "but got w[{}]: {}, type: {}.".format(position, weight, type(weight)))

        # num_samples is optional; validate its type first, then its range.
        if num_samples is not None:
            if not isinstance(num_samples, int):
                raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
            if not 0 <= num_samples <= validator.INT64_MAX:
                raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                                 .format(0, validator.INT64_MAX))

        if not isinstance(replacement, bool):
            raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))

        self.weights = weight_list
        self.replacement = replacement
        super().__init__(num_samples)

    def parse(self):
        """
        Build the cde sampler object that corresponds to this sampler.

        Returns:
            The `cde.WeightedRandomSamplerObj` created from the weights, sample count and
            replacement flag, with the parsed child sampler attached.
        """
        # Unset values fall back to 0 samples and replacement=True before reaching the cde layer.
        sample_count = 0 if self.num_samples is None else self.num_samples
        with_replacement = True if self.replacement is None else self.replacement
        weighted_sampler = cde.WeightedRandomSamplerObj(self.weights, sample_count, with_replacement)
        weighted_sampler.add_child(self.parse_child())
        return weighted_sampler

    def is_shuffled(self):
        """Report whether this sampler shuffles its output; it always returns True."""
        return True

    def is_sharded(self):
        """
        Tell whether sharding is in effect, which is decided entirely by the child sampler.

        Returns:
            False when no child sampler is attached, otherwise whatever the child's `is_sharded()` returns.
        """
        child = self.child_sampler
        return False if child is None else child.is_sharded()
867