• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16The sampler module provides several samplers to generate data from datasets.
The provided samplers include: DistributedSampler, PKSampler, RandomSampler,
SequentialSampler, SubsetSampler, SubsetRandomSampler, and WeightedRandomSampler.
19Users can also define a custom sampler by extending from the Sampler class.
20"""
21
22import numbers
23import numpy as np
24import mindspore._c_dataengine as cde
25import mindspore.dataset as ds
26from ..core import validator_helpers as validator
27
28
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
    """
    Choose the sampler a dataset should use from its sampling-related arguments.

    Args:
        num_samples (int): Number of samples, or None if unspecified.
        input_sampler (Union[Iterable, Sampler]): Sampler supplied by the user, or None.
        shuffle (bool): Whether to shuffle, or None if unspecified.
        num_shards (int): Number of shards for sharding, or None to disable sharding.
        shard_id (int): Shard ID.

    Returns:
        Sampler, sampler selected based on user input.

    Raises:
        ValueError: If a builtin sampler is passed together with a non-None
            `num_samples`, `num_shards`, `shard_id` or `shuffle`.
        TypeError: If `input_sampler` is of an unsupported type.
    """
    if input_sampler is None:
        sharded = num_shards is not None
        if shuffle is None:
            if sharded:
                # Shuffle unspecified + sharding: distributed sampler that shuffles.
                return DistributedSampler(num_shards, shard_id, shuffle=True, num_samples=num_samples)
            # Shuffle unspecified, no sharding: random sampler; a zero num_samples
            # does not switch on replacement in this branch.
            with_replacement = num_samples is not None and num_samples != 0
        elif shuffle is True:
            if sharded:
                # Shuffle requested + sharding: distributed sampler that shuffles.
                return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
            # Shuffle requested, no sharding: random sampler.
            with_replacement = num_samples is not None
        else:
            if sharded:
                # Shuffle disabled + sharding: distributed sequential sampler.
                return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
            # Shuffle disabled, no sharding: plain sequential sampler.
            return SequentialSampler(num_samples=num_samples)

        if with_replacement:
            return RandomSampler(replacement=True, num_samples=num_samples)
        return RandomSampler(num_samples=num_samples)

    # From here on the user handed us a sampler explicitly.
    if isinstance(input_sampler, BuiltinSampler):
        # A builtin sampler already carries its own sharding / shuffling / count
        # configuration. Passing those again at dataset level could contradict it, e.g.
        #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
        #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
        # so the dataset-level arguments must all be None in that case.
        dataset_level_args = (num_shards, shard_id, shuffle, num_samples)
        if not all(arg is None for arg in dataset_level_args):
            raise ValueError(
                'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
                ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
        return input_sampler

    if not isinstance(input_sampler, str):
        if isinstance(input_sampler, (np.ndarray, list, tuple)):
            # Concrete sequences of indices become a subset sampler.
            return SubsetSampler(input_sampler, num_samples)
        if validator.is_iterable(input_sampler):
            # Any other iterable is treated as a user-written sampler object
            # that is not of type BuiltinSampler.
            return IterSampler(input_sampler, num_samples)
        if isinstance(input_sampler, int):
            # A single index selects exactly that one sample.
            return SubsetSampler([input_sampler])

    raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
90
91
class BuiltinSampler:
    """
    Base class for BuiltinSampler.

    User should not extend this class.
    """

    def __init__(self, num_samples=None):
        self.num_samples = num_samples
        self.child_sampler = None

    def parse(self):
        """ Parse the sampler."""

    def add_child(self, sampler):
        """
        Attach a sub-sampler to this sampler. The parent receives the output of
        the sub-sampler and applies its own sampling logic to it to produce new samples.

        Args:
            sampler (Sampler): Object used to choose samples from the dataset. Only builtin
                samplers(:class:`mindspore.dataset.DistributedSampler` ,
                :class:`mindspore.dataset.PKSampler`,
                :class:`mindspore.dataset.RandomSampler`,
                :class:`mindspore.dataset.SequentialSampler`,
                :class:`mindspore.dataset.SubsetRandomSampler`,
                :class:`mindspore.dataset.WeightedRandomSampler` ) are supported.

        Raises:
            RuntimeError: If this sampler already has a child sampler.

        Examples:
            >>> import mindspore.dataset as ds
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> sampler.add_child(ds.RandomSampler(num_samples=4))
            >>> dataset = ds.Cifar10Dataset(cifar10_dataset_dir, sampler=sampler)
        """
        if self.child_sampler is None:
            self.child_sampler = sampler
            return
        raise RuntimeError("Cannot add child sampler, this sampler already has a child.")

    def get_child(self):
        """
        Return the child sampler attached to this sampler.

        Returns:
            Sampler, The child sampler of given sampler, or None if there is none.

        Examples:
            >>> import mindspore.dataset as ds
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> sampler.add_child(ds.RandomSampler(num_samples=2))
            >>> child_sampler = sampler.get_child()
        """
        return self.child_sampler

    def parse_child(self):
        """ Parse the child sampler. """
        return None if self.child_sampler is None else self.child_sampler.parse()

    def parse_child_for_minddataset(self):
        """ Parse the child sampler for MindRecord. """
        if self.child_sampler is None:
            return None
        return self.child_sampler.parse_for_minddataset()

    def is_shuffled(self):
        """ Not implemented. """
        raise NotImplementedError("Sampler must implement is_shuffled.")

    def is_sharded(self):
        """ Not implemented. """
        raise NotImplementedError("Sampler must implement is_sharded.")

    def get_num_samples(self):
        """
        Get `num_samples` value of the current sampler instance.
        This parameter can be optionally passed in when defining the Sampler. Default: ``None``.
        When the sampler has a child sampler, the child's value is queried as well
        and combined with this sampler's own value as described below.

        The following table shows the various possible combinations, and the final results returned.

        .. list-table::
           :widths: 25 25 25 25
           :header-rows: 1

           * - child sampler
             - num_samples
             - child_samples
             - result
           * - T
             - x
             - y
             - min(x, y)
           * - T
             - x
             - None
             - x
           * - T
             - None
             - y
             - y
           * - T
             - None
             - None
             - None
           * - None
             - x
             - n/a
             - x
           * - None
             - None
             - n/a
             - None

        Returns:
            int, the number of samples, or None.

        Examples:
            >>> import mindspore.dataset as ds
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> num_samplers = sampler.get_num_samples()
        """
        own_count = self.num_samples
        if self.child_sampler is None:
            return own_count

        child_count = self.child_sampler.get_num_samples()
        if own_count is None:
            return child_count
        if child_count is None:
            return own_count
        return min(own_count, child_count)
229
230
class Sampler(BuiltinSampler):
    """
    Base class for user defined sampler.
    A user defined sampler can be used with any existing dataset with sampler support.

    A required __iter__() method should be overridden by the user for sample index generation.
    An optional reset() method can be overridden for per repeat reset.

    dataset_size and num_samples will be set by dataset once a dataset iterator is created.

    Examples:
        >>> import mindspore.dataset as ds
        >>> class ReverseSampler(ds.Sampler):
        ...     def __iter__(self):
        ...         for i in range(self.dataset_size - 1, -1, -1):
        ...             yield i
        >>>
        >>> ds = ds.ImageFolderDataset(image_folder_dataset_dir, sampler=ReverseSampler())
    """

    def __init__(self, num_samples=None):
        # The base initializer already stores num_samples and clears child_sampler.
        super().__init__(num_samples)
        self.dataset_size = 0

    def __iter__(self):
        """
        User defined iterator, must be overridden.
        _handshake is guaranteed to be called prior to iterator construction.
        """
        raise NotImplementedError

    def reset(self):
        """
        Per repeat reset callback, override this method if necessary
        """

    # Initialization handshake callback
    # Do not override this method!
    def _handshake(self, ds_size, num_samples):
        self.num_samples = num_samples
        self.dataset_size = ds_size

    # Indices fetcher
    # Do not override this method!
    # pylint: disable=missing-docstring
    def _get_indices(self):
        exhausted = object()  # private marker for "the user iterator ran dry"
        cursor = iter(self)
        collected = []
        # Pull at most num_samples indices, stopping early if the iterator ends.
        for _ in range(self.num_samples):
            idx = next(cursor, exhausted)
            if idx is exhausted:
                break
            collected.append(idx)
        as_array = np.array(collected)
        # An object dtype means the yielded values did not form a regular numeric array.
        if as_array.dtype == object:
            raise RuntimeError("Fetched indices can not be converted to a valid ndarray.")
        return as_array

    # Instance fetcher
    # Do not override this method!
    def parse(self):
        """ Parse the sampler."""
        # An unset num_samples is passed on as 0.
        sample_count = 0 if self.num_samples is None else self.num_samples
        c_sampler = cde.PreBuiltSamplerObj(sample_count, self)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def add_child(self, sampler):
        # Unlike the base class, an existing child is replaced without complaint.
        self.child_sampler = sampler

    def get_child(self):
        return self.child_sampler

    def parse_child(self):
        if self.child_sampler is None:
            return None
        return self.child_sampler.parse()

    def is_shuffled(self):
        # Without a child there is nothing known to shuffle; otherwise defer to the child.
        child = self.child_sampler
        return False if child is None else child.is_shuffled()

    def is_sharded(self):
        # Without a child there is nothing known to shard; otherwise defer to the child.
        child = self.child_sampler
        return False if child is None else child.is_sharded()

    def get_num_samples(self):
        # With num_samples set, the count is the number of indices actually produced.
        if self.num_samples is not None:
            return self._get_indices().size
        return None
331
332
class DistributedSampler(BuiltinSampler):
    """
    A sampler that accesses a shard of the dataset, it helps divide dataset into multi-subset for distributed training.

    Args:
        num_shards (int): Number of shards to divide the dataset into.
        shard_id (int): Shard ID of the current shard, which should within the range of [0, `num_shards` - 1].
        shuffle (bool, optional): If True, the indices are shuffled, otherwise it will not be shuffled.
            Default: ``True``.
        num_samples (int, optional): The number of samples to draw. Default: ``None``, which means sample all elements.
        offset(int, optional): The starting shard ID where the elements in the dataset are sent to, which
            should be no more than `num_shards` . This parameter is only valid when a ConcatDataset takes
            a :class:`mindspore.dataset.DistributedSampler` as its sampler. It will affect the number of
            samples of per shard. Default: ``-1``, which means each shard has the same number of samples.

    Raises:
        TypeError: If `num_shards` is not of type int.
        TypeError: If `shard_id` is not of type int.
        TypeError: If `shuffle` is not of type bool.
        TypeError: If `num_samples` is not of type int.
        TypeError: If `offset` is not of type int.
        ValueError: If `num_samples` is a negative value.
        RuntimeError: If `num_shards` is not a positive value.
        RuntimeError: If `shard_id` is smaller than 0 or equal to `num_shards` or larger than `num_shards` .
        RuntimeError: If `offset` is greater than `num_shards` .

    Examples:
        >>> import mindspore.dataset as ds
        >>> # creates a distributed sampler with 10 shards in total. This shard is shard 5.
        >>> sampler = ds.DistributedSampler(10, 5)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
        # Argument checks run in declaration order so the first bad argument is reported.
        if not isinstance(num_shards, int):
            raise TypeError("num_shards must be integer but was: {}.".format(num_shards))
        if not isinstance(shard_id, int):
            raise TypeError("shard_id must be integer but was: {}.".format(shard_id))
        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle))
        if num_samples is not None and not isinstance(num_samples, int):
            raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
        if num_samples is not None and not 0 <= num_samples <= validator.INT64_MAX:
            raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                             .format(0, validator.INT64_MAX))
        if not isinstance(offset, int):
            raise TypeError("offset must be integer but was: {}.".format(offset))

        self.num_shards = num_shards
        self.shard_id = shard_id
        self.shuffle = shuffle
        # Seed used in the distributed scenario:
        #   1. if the user called ds.config.set_seed(4321), seed 4321 is used;
        #   2. otherwise the existing or default seed (like 5489) is used.
        self.seed = ds.config.get_seed()
        self.offset = offset
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        sample_count = 0 if self.num_samples is None else self.num_samples
        do_shuffle = True if self.shuffle is None else self.shuffle
        start_offset = -1 if self.offset is None else self.offset
        # Bump the seed on every parse so that each create_dict_iterator() call
        # (i.e. each repeat) shuffles with a different seed.
        self.seed += 1
        c_sampler = cde.DistributedSamplerObj(self.num_shards, self.shard_id,
                                              do_shuffle, sample_count, self.seed, start_offset, True)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def parse_for_minddataset(self):
        """ Parse the sampler for MindRecord."""
        sample_count = 0 if self.num_samples is None else self.num_samples
        do_shuffle = True if self.shuffle is None else self.shuffle
        # Here the seed is used as-is and the offset is forwarded unchanged.
        c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, do_shuffle,
                                                     self.seed, sample_count, self.offset)
        c_sampler.add_child(self.parse_child_for_minddataset())
        c_sampler.set_num_samples(sample_count)
        return c_sampler

    def is_shuffled(self):
        # With a child attached, the child's answer wins over our own shuffle flag.
        child = self.child_sampler
        return self.shuffle if child is None else child.is_shuffled()

    def is_sharded(self):
        # A lone sampler counts as sharded only when there is more than one shard.
        child = self.child_sampler
        return self.num_shards > 1 if child is None else child.is_sharded()

    def set_offset(self, offset):
        # Returns self so the call can be chained.
        self.offset = offset
        return self
437
438
class PKSampler(BuiltinSampler):
    """
    Samples K elements for each P class in the dataset.

    Args:
        num_val (int): Number of elements to sample for each class.
        num_class (int, optional): Number of classes to sample. Default: ``None`` , sample all classes.
            The parameter does not support to specify currently.
        shuffle (bool, optional): Whether to shuffle the class IDs. Default: ``False``.
        class_column (str, optional): Name of column with class labels for MindDataset. Default: ``'label'``.
        num_samples (int, optional): The number of samples to draw. Default: ``None`` , which means sample all elements.

    Raises:
        TypeError: If `num_val` is not of type int.
        TypeError: If `shuffle` is not of type bool.
        TypeError: If `class_column` is not of type str.
        TypeError: If `num_samples` is not of type int.
        NotImplementedError: If `num_class` is not ``None``.
        RuntimeError: If `num_val` is not a positive value.
        ValueError: If `num_samples` is a negative value.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # creates a PKSampler that will get 3 samples from every class.
        >>> sampler = ds.PKSampler(3)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
        if not isinstance(num_val, int):
            raise TypeError("num_val must be integer but was: {}.".format(num_val))

        # num_class is accepted in the signature but may not be specified yet.
        if num_class is not None:
            raise NotImplementedError("Not supported to specify num_class for PKSampler.")

        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle))

        if not isinstance(class_column, str):
            raise TypeError("class_column must be a str value but was: {}.".format(class_column))

        if num_samples is not None:
            if not isinstance(num_samples, int):
                raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
            if num_samples < 0 or num_samples > validator.INT64_MAX:
                raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                                 .format(0, validator.INT64_MAX))

        self.num_val = num_val
        self.shuffle = shuffle
        self.class_column = class_column  # work for minddataset
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        # An unset num_samples is passed on as 0.
        num_samples = self.num_samples if self.num_samples is not None else 0
        shuffle = self.shuffle if self.shuffle is not None else False
        c_sampler = cde.PKSamplerObj(self.num_val, shuffle, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        # With a child attached, the child's answer wins over our own shuffle flag.
        if self.child_sampler is None:
            return self.shuffle

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        # Without a child there is no sharding; otherwise defer to the child.
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()

    def parse_for_minddataset(self):
        """
        Parse the sampler for MindRecord.

        Raises:
            ValueError: If `class_column` is empty or is not a string.
        """
        if not self.class_column or not isinstance(self.class_column, str):
            # Adjacent literals are used instead of a backslash continuation inside
            # the string, which would embed the next line's indentation in the message.
            raise ValueError("class_column should be a not empty string value, "
                             "but got class_column: {}.".format(self.class_column))
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        c_sampler.set_num_samples(num_samples)
        return c_sampler
525
526
class RandomSampler(BuiltinSampler):
    """
    Samples the elements randomly.

    Args:
        replacement (bool, optional): If True, put the sample ID back for the next draw. Default: ``False``.
        num_samples (int, optional): Number of elements to sample. Default: ``None`` , which means sample all elements.

    Raises:
        TypeError: If `replacement` is not of type bool.
        TypeError: If `num_samples` is not of type int.
        ValueError: If `num_samples` is a negative value.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # creates a RandomSampler
        >>> sampler = ds.RandomSampler()
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, replacement=False, num_samples=None):
        if not isinstance(replacement, bool):
            raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))
        if num_samples is not None and not isinstance(num_samples, int):
            raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
        if num_samples is not None and not 0 <= num_samples <= validator.INT64_MAX:
            raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                             .format(0, validator.INT64_MAX))

        self.deterministic = False
        self.replacement = replacement
        self.reshuffle_each_epoch = True
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        # An unset num_samples is passed on as 0, an unset replacement as False.
        sample_count = 0 if self.num_samples is None else self.num_samples
        with_replacement = False if self.replacement is None else self.replacement
        c_sampler = cde.RandomSamplerObj(with_replacement, sample_count, self.reshuffle_each_epoch)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        sample_count = 0 if self.num_samples is None else self.num_samples
        c_sampler = cde.MindrecordRandomSampler(sample_count, self.replacement, self.reshuffle_each_epoch)
        c_sampler.add_child(self.parse_child_for_minddataset())
        c_sampler.set_num_samples(sample_count)
        return c_sampler

    def is_shuffled(self):
        # A random sampler always reports itself as shuffled, child or not.
        return True

    def is_sharded(self):
        # Without a child there is no sharding; otherwise defer to the child.
        child = self.child_sampler
        return False if child is None else child.is_sharded()
591
592
class SequentialSampler(BuiltinSampler):
    """
    Samples the dataset elements sequentially that is equivalent to not using a sampler.

    Args:
        start_index (int, optional): Index to start sampling at. Default: ``None`` , start at first ID.
        num_samples (int, optional): Number of elements to sample. Default: ``None`` , which means sample all elements.

    Raises:
        TypeError: If `start_index` is not of type int.
        TypeError: If `num_samples` is not of type int.
        RuntimeError: If `start_index` is a negative value.
        ValueError: If `num_samples` is a negative value.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # creates a SequentialSampler
        >>> sampler = ds.SequentialSampler()
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, start_index=None, num_samples=None):
        if not (start_index is None or isinstance(start_index, int)):
            raise TypeError("start_index must be integer but was: {}.".format(start_index))
        if num_samples is not None and not isinstance(num_samples, int):
            raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
        if num_samples is not None and not 0 <= num_samples <= validator.INT64_MAX:
            raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                             .format(0, validator.INT64_MAX))

        self.start_index = start_index
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        # Unset start_index / num_samples are both passed on as 0.
        first = 0 if self.start_index is None else self.start_index
        sample_count = 0 if self.num_samples is None else self.num_samples
        c_sampler = cde.SequentialSamplerObj(first, sample_count)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        first = 0 if self.start_index is None else self.start_index
        sample_count = 0 if self.num_samples is None else self.num_samples
        # Note the argument order: sample count first, then start index.
        c_sampler = cde.MindrecordSequentialSampler(sample_count, first)
        c_sampler.add_child(self.parse_child_for_minddataset())
        c_sampler.set_num_samples(sample_count)
        return c_sampler

    def is_shuffled(self):
        # Without a child nothing is shuffled; otherwise defer to the child.
        child = self.child_sampler
        return False if child is None else child.is_shuffled()

    def is_sharded(self):
        # Without a child there is no sharding; otherwise defer to the child.
        child = self.child_sampler
        return False if child is None else child.is_sharded()
660
661
class SubsetSampler(BuiltinSampler):
    """
    Samples the elements from a sequence of indices.

    Args:
        indices (Iterable): A sequence of indices (Any iterable Python object but string).
        num_samples (int, optional): Number of elements to sample. Default: ``None`` , which means sample all elements.

    Raises:
        TypeError: If elements of `indices` are not of type number.
        TypeError: If `num_samples` is not of type int.
        ValueError: If `num_samples` is a negative value.

    Examples:
        >>> import mindspore.dataset as ds
        >>> indices = [0, 1, 2, 3, 4, 5]
        >>>
        >>> # creates a SubsetSampler, will sample from the provided indices
        >>> sampler = ds.SubsetSampler(indices)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, indices, num_samples=None):
        def _materialize(candidates, limit=None):
            # Turn an iterable of sample IDs into a list, keeping at most `limit` of them.
            if limit is None:
                return list(candidates)
            if isinstance(candidates, list):
                return candidates[:limit]
            taken = []
            for sample_id, _ in zip(candidates, range(limit)):
                taken.append(sample_id)
            return taken

        if num_samples is not None and not isinstance(num_samples, int):
            raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
        if num_samples is not None and not 0 <= num_samples <= validator.INT64_MAX:
            raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                             .format(0, validator.INT64_MAX))

        # Strings are iterable but are never accepted as a collection of indices.
        is_index_iterable = not isinstance(indices, str) and validator.is_iterable(indices)
        if is_index_iterable:
            indices = _materialize(indices, num_samples)
        elif isinstance(indices, int):
            # A lone integer is treated as a one-element index list.
            indices = [indices]
        else:
            raise TypeError('Unsupported sampler object of type ({})'.format(type(indices)))

        # Every retained element must be a Python or NumPy integer.
        for position, element in enumerate(indices):
            if isinstance(element, (int, np.integer)):
                continue
            raise TypeError("SubsetSampler: Type of indices element must be int, "
                            "but got list[{}]: {}, type: {}.".format(position, element, type(element)))

        self.indices = indices
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        # An unset num_samples is passed on as 0.
        sample_count = 0 if self.num_samples is None else self.num_samples
        c_sampler = cde.SubsetSamplerObj(self.indices, sample_count)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def is_shuffled(self):
        # A subset sampler always reports itself as not shuffled.
        return False

    def is_sharded(self):
        # Without a child there is no sharding; otherwise defer to the child.
        child = self.child_sampler
        return False if child is None else child.is_sharded()

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        c_sampler = cde.MindrecordSubsetSampler(self.indices)
        c_sampler.add_child(self.parse_child_for_minddataset())
        c_sampler.set_num_samples(self.get_num_samples())
        return c_sampler

    def get_num_samples(self):
        # The result can never exceed the number of stored indices.
        requested = super().get_num_samples()
        available = len(self.indices)
        return available if requested is None else min(available, requested)
749
750
class SubsetRandomSampler(SubsetSampler):
    """
    Samples the elements randomly from a sequence of indices.

    Args:
        indices (Iterable): A sequence of indices (Any iterable Python object but string).
        num_samples (int, optional): Number of elements to sample. Default: ``None`` , which means sample all elements.

    Raises:
        TypeError: If elements of `indices` are not of type int.
        TypeError: If `num_samples` is not of type int.
        ValueError: If `num_samples` is a negative value.

    Examples:
        >>> import mindspore.dataset as ds
        >>> indices = [0, 1, 2, 3, 7, 88, 119]
        >>>
        >>> # create a SubsetRandomSampler, will sample from the provided indices
        >>> sampler = ds.SubsetRandomSampler(indices)
        >>> data = ds.ImageFolderDataset(image_folder_dataset_dir, num_parallel_workers=8, sampler=sampler)
    """

    def parse(self):
        """Create the cde.SubsetRandomSamplerObj for the stored indices and attach the parsed child to it."""
        # An unset num_samples is handed over as 0.
        sample_limit = 0 if self.num_samples is None else self.num_samples
        random_subset_obj = cde.SubsetRandomSamplerObj(self.indices, sample_limit)
        random_subset_obj.add_child(self.parse_child())
        return random_subset_obj

    def is_shuffled(self):
        """Return True: this sampler always reports itself as shuffled."""
        return True

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord, passing the globally configured seed along."""
        mindrecord_sampler = cde.MindrecordSubsetSampler(self.indices, ds.config.get_seed())
        mindrecord_sampler.add_child(self.parse_child_for_minddataset())
        mindrecord_sampler.set_num_samples(self.get_num_samples())
        return mindrecord_sampler
791
792
class IterSampler(Sampler):
    """
    Wraps a user-provided iterable object that does not inherit from our Sampler class.

    Note:
        This class exists to allow handshake logic between dataset operations and user defined samplers.
        By constructing this object we avoid the user having to inherit from our Sampler class.

    Args:
        sampler (iterable object): A user-defined iterable object.
        num_samples (int, optional): Number of elements to sample. Default: ``None`` , which means sample all elements.

    Examples:
        >>> import mindspore.dataset as ds
        >>> class MySampler:
        ...     def __iter__(self):
        ...         for i in range(99, -1, -1):
        ...             yield i

        >>> # creates an IterSampler
        >>> sampler = ds.IterSampler(sampler=MySampler())
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, sampler, num_samples=None):
        """
        Store the user iterable and, when `num_samples` is not given, count its elements.

        Counting requires one full pass over `sampler`. If `sampler` is a one-shot iterator
        (for example a generator object), that pass exhausts it, so the collected sample ids
        are stored in its place. Re-iterable objects are stored unchanged.
        """
        if num_samples is None:
            first_pass = iter(sampler)
            sample_ids = list(first_pass)
            num_samples = len(sample_ids)
            # Only an object that is its own iterator can have been drained for good. Probe a
            # second pass to tell a one-shot iterator from one that restarts in __iter__: if
            # nothing comes back, fall back to the ids collected during the first pass.
            if first_pass is sampler and sample_ids:
                exhausted = object()
                if next(iter(sampler), exhausted) is exhausted:
                    sampler = sample_ids
        super().__init__(num_samples=num_samples)
        self.sampler = sampler

    def __iter__(self):
        """Return a fresh iterator over the wrapped sampler."""
        return iter(self.sampler)
827
828
class WeightedRandomSampler(BuiltinSampler):
    """
    Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).

    Args:
        weights (list[float, int]): A sequence of weights, not necessarily summing up to 1.
        num_samples (int, optional): Number of elements to sample. Default: ``None`` , which means sample all elements.
        replacement (bool): If ``True``, put the sample ID back for the next draw. Default: ``True``.

    Raises:
        TypeError: If elements of `weights` are not of type number.
        TypeError: If `num_samples` is not of type int.
        TypeError: If `replacement` is not of type bool.
        RuntimeError: If `weights` is empty or all zero.
        ValueError: If `num_samples` is a negative value.

    Examples:
        >>> import mindspore.dataset as ds
        >>> weights = [0.9, 0.01, 0.4, 0.8, 0.1, 0.1, 0.3]
        >>>
        >>> # creates a WeightedRandomSampler that will sample 4 elements with replacement (the default)
        >>> sampler = ds.WeightedRandomSampler(weights, 4)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, weights, num_samples=None, replacement=True):
        # Anything that is not a list is treated as a single weight.
        weight_list = weights if isinstance(weights, list) else [weights]

        for position, weight in enumerate(weight_list):
            if isinstance(weight, numbers.Number):
                continue
            raise TypeError("type of weights element must be number, "
                            "but got w[{}]: {}, type: {}.".format(position, weight, type(weight)))

        if num_samples is not None:
            if not isinstance(num_samples, int):
                raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
            if not 0 <= num_samples <= validator.INT64_MAX:
                raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                                 .format(0, validator.INT64_MAX))

        if not isinstance(replacement, bool):
            raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))

        self.weights = weight_list
        self.replacement = replacement
        super().__init__(num_samples)

    def parse(self):
        """Create the cde.WeightedRandomSamplerObj for the stored weights and attach the parsed child to it."""
        # Unset values are handed over as 0 samples and replacement enabled.
        sample_limit = 0 if self.num_samples is None else self.num_samples
        with_replacement = True if self.replacement is None else self.replacement
        weighted_obj = cde.WeightedRandomSamplerObj(self.weights, sample_limit, with_replacement)
        weighted_obj.add_child(self.parse_child())
        return weighted_obj

    def is_shuffled(self):
        """Return True: this sampler always reports itself as shuffled."""
        return True

    def is_sharded(self):
        """Return the child sampler's sharding status, or False when there is no child sampler."""
        child = self.child_sampler
        return False if child is None else child.is_sharded()
896