# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset helper for MindData dataset."""
from __future__ import absolute_import

import math
import copy

from mindspore import _checkparam as Validator
from mindspore import log as logger
from mindspore.common._auto_dynamic import is_auto_dynamic, convert_new_shapes
from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.api import _cell_graph_executor, _is_args_fullmode, ARG_SPECIFIED
from mindspore.common._utils import is_shape_unknown
from mindspore.dataset.engine import offload
from mindspore import context, nn
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
from mindspore.parallel._utils import _get_device_num, _get_global_rank, _need_to_full, \
    _to_full_shapes, _get_pipeline_stages, _change_symbols_for_parallel, _is_in_auto_parallel_mode, \
    _origin_shapes, _dynamic_shape_for_dataset
from mindspore.parallel._ps_context import _is_role_sched
from mindspore.ops import operations as P
from mindspore.common.auto_dynamic_shape import _auto_dynamic_shape


def _send_data(dataset, epoch_num):
    """Send data from the engine dataset to the TDT queue."""
    if not hasattr(dataset, '__has_sent__'):
        exec_dataset = dataset.__transfer_dataset__
        exec_dataset.send(epoch_num)
        dataset.__has_sent__ = True


def _send_data_no_flag(dataset, epoch_num):
    """Send data from the engine dataset to the TDT queue directly, without marking it as sent."""
    exec_dataset = dataset.__transfer_dataset__
    exec_dataset.send(epoch_num)


def _dynamic_sink_data(dataset, dataset_iter):
    """Check the special scenario for a dataset with sink_size=1."""
    if hasattr(dataset_iter, "sink_size") and \
       dataset_iter.sink_size == 1 and \
       dataset.get_dataset_size() != 1 and \
       not hasattr(dataset, "__no_send__") and \
       hasattr(dataset_iter, "sink_count") and \
       dataset_iter.sink_count == 1:
        return True
    return False


def _dynamic_sink_exception_scenario(dataset_iter, is_dynamic):
    """Check the exception scenarios in which dynamic sinking is not applicable."""
    if context.get_context("mode") != context.GRAPH_MODE or is_dynamic:
        return True
    return False


def _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic):
    """Check the special scenario with dynamic shape and sink_size=1."""
    flag = False

    # This is used only for testing
    if is_auto_dynamic():
        return False

    if _dynamic_sink_data(dataset, dataset_iter) and not _dynamic_sink_exception_scenario(dataset_iter, is_dynamic):
        flag = True

    return flag


class _DataWrapper(nn.Cell):
    """
    Wraps the input network with a dataset which automatically fetches data with the 'GetNext' operator from the
    dataset channel 'queue_name' and performs the forward computation.
    """

    def __init__(self, network, dataset_types, dataset_shapes, queue_name):
        super(_DataWrapper, self).__init__(
            auto_prefix=False, flags=network.get_flags())
        # Also copy the flags set on `network.construct`
        flags = getattr(network.__class__.construct, "_func_graph_flags", {})
        self.info = (dataset_types, dataset_shapes)
        self.add_flags(**flags)
        self.get_next = P.GetNext(
            dataset_types, dataset_shapes, len(dataset_types), queue_name)
        if network.get_inputs() is not None:
            network_inputs = network.get_inputs()
            is_fullmode = _is_args_fullmode(network_inputs, False)
            if is_fullmode:
                symbol_inputs = [getattr(inp, "symbolic_shape", None) for inp in network.get_inputs()]
            else:
                symbol_inputs = [None for _ in dataset_shapes]
                arg_specified = network_inputs.get(ARG_SPECIFIED, [])
                for idx, inp in arg_specified:
                    symbol_inputs[idx] = getattr(inp, "symbolic_shape", None)
            symbols_for_parallel = _change_symbols_for_parallel(dataset_shapes, copy.deepcopy(symbol_inputs))
            if any(s is not None for s in symbols_for_parallel):
                self.get_next.add_prim_attr("symbols", symbol_inputs)
                self.get_next.add_prim_attr("symbols_for_parallel", symbols_for_parallel)
        self.network = network
        self._get_attr_from_cell(network)

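    # `construct` takes no arguments: each call pulls the next batch from the dataset
    # channel via GetNext and feeds it to the wrapped network.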
    def construct(self):
        outputs = self.get_next()
        return self.network(*outputs)


def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name):
    if not isinstance(network, _DataWrapper):
        network = _DataWrapper(
            network, dataset_types, dataset_shapes, queue_name)
    return network


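# A shape counts as dynamic when any dimension is unknown, e.g. (-1, 10) for an unknown
# batch dimension or (-2,) for a fully unknown rank (see the PyNative branch below).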
def _has_dynamic_shape(dataset_shapes):
    for shape in dataset_shapes:
        if is_shape_unknown(shape):
            return True
    return False


def _generate_network_with_dataset(network, dataset_helper, queue_name):
    """
    Generate a new network from the given network and the dataset info.
    """
    dataset_types, dataset_shapes = dataset_helper.types_shapes()

    # This is used only for testing
    if is_auto_dynamic():
        new_shapes = convert_new_shapes(dataset_shapes)
        return _generate_dataset_sink_mode_net(network, new_shapes, dataset_types, queue_name)

    if network.get_inputs() and None not in network.get_inputs():
        if _is_in_auto_parallel_mode():
            # Here the dataset shapes have been processed by full_shape(), so restore them to the
            # original shapes first. _check_inputs() changes a static origin_shape to a dynamic shape;
            # after _check_inputs(), convert dataset_shapes to dynamic shapes as well.
            origin_shape = _origin_shapes(dataset_shapes)
            _check_inputs(network.get_inputs(), origin_shape, dataset_types)
            dataset_shapes = _dynamic_shape_for_dataset(dataset_shapes, origin_shape)
        else:
            _check_inputs(network.get_inputs(), dataset_shapes, dataset_types)
    elif context.get_context("mode") == context.PYNATIVE_MODE:
        dataset_shapes = tuple([(-2,)] * len(dataset_shapes))
    network = _generate_dataset_sink_mode_net(
        network, dataset_shapes, dataset_types, queue_name)
    return network


def _check_inputs(network_shapes, dataset_shapes, dataset_types):
    """
    Check whether the inputs configured via set_inputs match the dataset.
    """
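    # Illustrative example: if set_inputs declares a Tensor with shape [-1, 10] and the dataset
    # yields shape (32, 10) with a matching dtype, the check passes and the corresponding dataset
    # dimension is rewritten to -1 (dynamic); a dtype or rank mismatch raises an error instead.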
    if not _is_args_fullmode(network_shapes, False):
        temp_network_shapes = [None for _ in dataset_shapes]
        arg_specified = network_shapes.get(ARG_SPECIFIED, [])
        for idx, inp in arg_specified:
            temp_network_shapes[idx] = inp
        network_shapes = temp_network_shapes

    for tensor_index, ele_dataset_shape in enumerate(dataset_shapes):
        if network_shapes[tensor_index] is None:
            continue
        set_inputs_shape = list(network_shapes[tensor_index].shape)
        inputs_shape = list(ele_dataset_shape)
        if dataset_types[tensor_index] != network_shapes[tensor_index].dtype:
            raise TypeError(
                f"The {tensor_index+1}th input type of 'set_inputs' must be the same as network's input, "
                f"but got 'set_inputs': {network_shapes[tensor_index].dtype} and network's "
                f"input: {dataset_types[tensor_index]}."
            )
        if len(inputs_shape) != len(set_inputs_shape):
            raise ValueError(
                f"The {tensor_index + 1}th input dims of 'set_inputs' must be the same as network's input, "
                f"but got 'set_inputs': {len(set_inputs_shape)} and network's input: {len(inputs_shape)}.")
        for index, ele_shape in enumerate(ele_dataset_shape):
            if network_shapes[tensor_index].shape[index] != -1:
                if set_inputs_shape[index] != ele_shape:
                    raise ValueError(
                        f"The {index + 1}th input shape of 'set_inputs' must be the same as network's input, "
                        f"but got 'set_inputs': {set_inputs_shape[index]} and network's input: "
                        f"{dataset_shapes[tensor_index][index]}.")
            else:
                dataset_shapes[tensor_index][index] = -1


class _DatasetAux:
    @staticmethod
    def __deepcopy__(memodict):
        return


def _get_dataset_aux(dataset):
    if not hasattr(dataset, '__network_aux__'):
        dataset.__network_aux__ = _DatasetAux()
    return dataset.__network_aux__


def connect_network_with_dataset(network, dataset_helper):
    """
    Connect the `network` with the dataset in `dataset_helper`. Only supported in `sink mode
    <https://mindspore.cn/tutorials/experts/en/master/optimize/execution_opt.html>`_ (dataset_sink_mode=True).

    Args:
        network (Cell): The training network for the dataset.
        dataset_helper (DatasetHelper): A class to process the MindData dataset; it provides the type, shape and queue
            name of the dataset.

    Returns:
        Cell, a new network containing the type, shape and queue name of the dataset info.

    Raises:
        RuntimeError: If the API was not called in dataset sink mode.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>> from mindspore import dataset as ds
        >>>
        >>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
        >>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
        >>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True)
        >>> net = nn.Dense(10, 5)
        >>> net_with_dataset = ms.connect_network_with_dataset(net, dataset_helper)
    """
    dataset_iter = dataset_helper.iter
    dataset = dataset_iter.dataset
    aux = _get_dataset_aux(dataset)

    if isinstance(dataset_iter, _DatasetIterNormal):
        raise RuntimeError(
            "The API 'connect_network_with_dataset' should be called in dataset sink mode.")

    if _is_role_sched():
        network.add_flags(sink_mode=True)
        return network

    if not hasattr(aux, '__network__'):
        aux.__network__ = network

    if aux.__network__ is not network:
        raise ValueError(
            "The dataset has already been connected to another network, please check the code.")
    is_dynamic = bool(network.get_inputs())
    queue_name = dataset.__transfer_dataset__.queue_name
    if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic):
        dataset_types, dataset_shapes = dataset_helper.get_data_info()
        # Need to apply full_batch to the shapes here, as is also done in _DatasetIterMSLoopSink
        if _need_to_full():
            dataset_shapes = _to_full_shapes(dataset_shapes, _get_device_num() // _get_pipeline_stages())
        dataset_types = [pytype_to_dtype(x) for x in dataset_types]
        if not is_dynamic:
            dataset_shapes = _auto_dynamic_shape.auto_dynamic_generate_compile_args(dataset_shapes, True)
        key = str(dataset_types) + str(dataset_shapes)

        if hasattr(aux, "__shape_type__") and aux.__shape_type__ != key:
            _auto_dynamic_shape.update_phase_and_compile_args(dataset_shapes, key, True, aux)
            if hasattr(aux, '__network_manage__') and key in aux.__network_manage__:
                network = aux.__network_manage__[key]
            else:
                if _need_to_full():
                    device_num = _get_device_num() // _get_pipeline_stages()
                    dataset_shapes = _to_full_shapes(dataset_shapes, device_num)

                network = _generate_dataset_sink_mode_net(
                    network, dataset_shapes, dataset_types, queue_name)
                if not hasattr(aux, '__network_manage__'):
                    aux.__network_manage__ = dict()
                aux.__network_manage__[key] = network
            network.add_flags(sink_mode=True)
            return network

    if hasattr(aux, '__sink_network__'):
        network = aux.__sink_network__
    elif context.get_context("device_target") in ("Ascend", "GPU"):
        network = offload.check_add_offload_sink_mode(
            dataset, dataset_helper, network)
        network = _generate_network_with_dataset(
            network, dataset_helper, queue_name)
        aux.__sink_network__ = network
        dataset_types, dataset_shapes = dataset_helper.types_shapes()
        aux.__shape_type__ = str(dataset_types) + str(dataset_shapes)

    if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter, is_dynamic):
        dataset_helper.get_data_info()
    network.add_flags(sink_mode=True)
    return network


class DatasetHelper:
    """
    DatasetHelper is a class to process the MindData dataset and provides information about the dataset.

    According to the context, it adapts the dataset iteration so that the same iteration loop can be used
    in different contexts.

    Note:
        One iteration of DatasetHelper provides one epoch of data.

    Args:
        dataset (Dataset): The dataset iterator. The dataset can be generated by dataset generator API in
                           `mindspore.dataset` module, such as :class:`mindspore.dataset.ImageFolderDataset`.
        dataset_sink_mode (bool): If the value is True, GetNext is employed to fetch the data at device through the
                                  dataset pipeline, otherwise fetch the data at host by iterating through the dataset.
                                  Default: ``True``.
        sink_size (int): Control the amount of data in each sink.
                          If sink_size=-1, sink the complete dataset for each epoch.
                          If sink_size>0, sink `sink_size` batches of data for each epoch.
                          Default: -1.
        epoch_num (int): The number of passes of the entire dataset to be sent. Default: 1.

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>> from mindspore import dataset as ds
        >>>
        >>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
        >>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
        >>> set_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=False)
        >>>
        >>> net = nn.Dense(10, 5)
        >>> # Object of DatasetHelper is iterable
        >>> for next_element in set_helper:
        ...     # `next_element` includes data and label, using data to run the net
        ...     data = next_element[0]
        ...     result = net(data)
    """

    def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
        Validator.check_is_int(sink_size)
        if sink_size < -1 or sink_size == 0:
            raise ValueError(
                "The 'sink_size' must be -1 or positive, but got sink_size {}.".format(sink_size))
        if sink_size == -1:
            sink_size = dataset.get_dataset_size()

        if dataset_sink_mode:
            if context.get_context("mode") == context.GRAPH_MODE:
                if _is_role_sched():
                    iterclass = _DatasetIterPSServer
                elif (context.get_context("device_target") == "Ascend") or \
                        (context.get_context("device_target") == "GPU"):
                    iterclass = _DatasetIterMSLoopSink
                else:
                    target = context.get_context("device_target")
                    raise RuntimeError("Currently dataset sink mode is not supported when the device "
                                       "target is {}, please set dataset_sink_mode to False "
                                       "in Model.train()".format(target))
            else:
                iterclass = _DatasetIterPyNative
            self.iter = iterclass(dataset, sink_size, epoch_num)
        else:
            iterclass = _DatasetIterNormal
            self.iter = iterclass(dataset, epoch_num=epoch_num)

    def __iter__(self):
        return self.iter.__iter__()

    # A temp solution for loop sink. Delete later
    def types_shapes(self):
        """
        Get the types and shapes from the dataset under the current configuration.

        Examples:
            >>> import mindspore as ms
            >>> import numpy as np
            >>>
            >>> # Define a dataset pipeline
            >>> def generator():
            ...    for i in range(5):
            ...        yield (np.ones((32, 10)),)
            >>>
            >>> train_dataset = ms.dataset.GeneratorDataset(generator, ["data"])
            >>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True)
            >>>
            >>> types, shapes = dataset_helper.types_shapes()
        """
        return self.iter.types_shapes()

    def sink_size(self):
        """
        Get sink_size for each iteration.

        Examples:
            >>> import mindspore as ms
            >>> import numpy as np
            >>>
            >>> # Define a dataset pipeline
            >>> def generator():
            ...    for i in range(5):
            ...        yield (np.ones((32, 10)),)
            >>>
            >>> train_dataset = ms.dataset.GeneratorDataset(generator, ["data"])
            >>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True, sink_size=-1)
            >>>
            >>> # if sink_size == -1, the full size of the source dataset is returned.
            >>> sink_size = dataset_helper.sink_size()
        """
        return self.iter.get_sink_size()

    def stop_send(self):
        """
        Stop sending data in data sink mode.

        Examples:
            >>> import mindspore as ms
            >>> import numpy as np
            >>> # Define a dataset pipeline
            >>> def generator():
            ...    for i in range(5):
            ...        yield (np.ones((32, 10)),)
            >>> train_dataset = ms.dataset.GeneratorDataset(generator, ["data"])
            >>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True, sink_size=-1)
            >>> dataset_helper.stop_send()
        """
        self.iter.stop_send()

    def release(self):
        """
        Free up the resources used by data sinking.

        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import nn
            >>> from mindspore import dataset as ds
            >>>
            >>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
            >>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
            >>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True)
            >>> dataset_helper.release()
        """
        self.iter.release()

    def continue_send(self):
        """
        Continue to send data to the device at the beginning of an epoch.

        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import nn
            >>> from mindspore import dataset as ds
            >>>
            >>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
            >>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
            >>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True)
            >>> dataset_helper.continue_send()
        """
        self.iter.continue_send()

    def _reset(self, step, dataset_size):
        """Reset the dataset to the provided step and epoch."""
        self.iter._reset(step, dataset_size)  # pylint: disable=protected-access

    # pylint: disable=missing-docstring
    def get_data_info(self):
        # In sink mode, it returns the types and shapes of the current data.
        # Generally, it works in dynamic shape scenarios.
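        # A minimal usage sketch (illustrative only), assuming a helper created with
        # dataset_sink_mode=True over a dynamic-shape dataset:
        #     types, shapes = dataset_helper.get_data_info()
        #     # `types` are Python types; convert with pytype_to_dtype() before compiling,
        #     # as done in connect_network_with_dataset above.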
        return self.iter.get_data_info()

    # pylint: disable=missing-docstring
    def get_mbuf_queue_size(self):
        # In sink mode, it returns the number of elements inside the mbuf channel.
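        # Illustrative use: pending = dataset_helper.get_mbuf_queue_size() gives how many
        # elements are currently buffered in the device-side mbuf channel.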
        return self.iter.get_mbuf_queue_size()

    # pylint: disable=missing-docstring
    def get_send_info(self, run_context):
        # In sink mode, it returns the send information of the dataset at this moment.
        # Send information includes the number of send batches, a time summary of fetching data on the host,
        # and a time summary of sending data.
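        # A minimal usage sketch (illustrative only), assuming this is queried from a
        # mindspore.train.Callback method where `run_context` is available:
        #     info = dataset_helper.get_send_info(run_context)
        #     stats = info.epoch(1)
        #     # stats is a dict with 'fetch_data_num', 'fetch_data_time' and 'first_data_time'.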
        class InfoViewer:
            '''
            Inner class for parsing send info.
            '''

            def __init__(self, send_info, run_context):
                self.info_ = {}
                self.sink_size = run_context.original_args()["batch_num"]
                if run_context.original_args().get("train_dataset", None) is not None:
                    self.dataset_size = run_context.original_args()["train_dataset"].get_dataset_size()
                elif run_context.original_args().get("valid_dataset", None) is not None:
                    self.dataset_size = run_context.original_args()["valid_dataset"].get_dataset_size()
                else:
                    raise RuntimeError("Could not find a proper dataset to estimate dataset size.")
                if not send_info:
                    epoch = 1
                    self.info_[epoch] = {'fetch_data_num': 0, 'fetch_data_time': 0, 'first_data_time': 0}
                else:
                    for info_per_epoch in send_info:
                        epoch, fetch_data_num, first_data_time, fetch_data_time = info_per_epoch
                        if fetch_data_num > 1:
                            fetch_data_time = (fetch_data_time - first_data_time) / (fetch_data_num - 1) * 1000.
                        self.info_[epoch] = {'fetch_data_num': fetch_data_num,
                                             'fetch_data_time': fetch_data_time,
                                             'first_data_time': first_data_time}

            def epoch(self, epoch):
                if self.sink_size == self.dataset_size:
                    return self.info_[epoch]
                global_step = epoch * self.sink_size
                data_epoch = math.ceil(global_step / self.dataset_size)
                return self.info_[data_epoch]

        # send info struct: [epoch, data_num_per_epoch, first_data_time, accumulate_data_time]
        # for example [1, 1875, 0.421, 0.362]
        send_info = self.iter.get_send_info()
        return InfoViewer(send_info, run_context)


class _DatasetIter:
    """Base iter for dataset helper."""

    def __init__(self, dataset, sink_size, epoch_num):
        self.dataset = dataset
        self.sink_size = sink_size
        self.sink_count = self.get_sink_count(dataset)
        self.dataset_types, self.dataset_shapes = _get_types_and_shapes(
            dataset)

        if dataset.get_init_step() % sink_size != 0:
            init_epoch = dataset.get_init_step() // sink_size
            init_step = init_epoch * sink_size
            logger.warning("Init global step must be the end of the epoch in sink mode, "
                           "but got: {0}. Reset it to the end of epoch {1} at step {2}."
                           .format(dataset.get_init_step(), init_epoch, init_step))
            dataset.set_init_step(init_step)

        if not hasattr(dataset, '__transfer_dataset__'):
            if hasattr(dataset, '__loop_size__'):
                self.sink_size = dataset.__loop_size__
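            # The data-info queue is only needed for the dynamic-shape sink_size=1 path,
            # where get_data_info() is polled for the per-step types and shapes.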
            create_data_info_queue = (
                sink_size == 1 and self.sink_count == 1 and dataset.get_dataset_size() != 1)
            dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size,
                                                           create_data_info_queue=create_data_info_queue)

            if not hasattr(dataset, '__no_send__'):
                _send_data(dataset, epoch_num)
        else:
            # if using an existing __transfer_dataset__, set the queue_name directly
            if not dataset.__transfer_dataset__.queue_name:
                _cell_graph_executor.set_queue_name(
                    dataset.__transfer_dataset__.queue_name)
            _send_data_no_flag(dataset, epoch_num)

        self.stop_send = dataset.__transfer_dataset__.stop_send
        self.release = dataset.__transfer_dataset__.release
        self.continue_send = dataset.__transfer_dataset__.continue_send
        self.get_data_info = dataset.__transfer_dataset__.get_data_info
        self.get_mbuf_queue_size = dataset.__transfer_dataset__.get_mbuf_queue_size
        self.get_send_info = dataset.__transfer_dataset__.get_send_info
        if hasattr(dataset.__transfer_dataset__, "_reset"):
            self._reset = dataset.__transfer_dataset__._reset  # pylint: disable=protected-access

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.index >= self.sink_count:
            raise StopIteration()
        self.index += 1
        return self.op()

    def types_shapes(self):
        """
        Return the types and shapes of the dataset. The type and shape of each element in the dataset
        should be consistent.
        """
        return self.dataset_types, self.dataset_shapes

    def get_sink_count(self, dataset):
        sink_count = 1
        if hasattr(dataset, '__loop_size__'):
            loop_size = dataset.__loop_size__
            if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
                raise ValueError(f"Dataset size {dataset.get_dataset_size()} and 'sink_size' {loop_size} "
                                 f"are not matched, dataset size should be divisible by 'sink_size'.")
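            # Illustrative example: with get_dataset_size() == 1875 and __loop_size__ == 625,
            # sink_count becomes math.ceil(1875 / 625) == 3.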
            sink_count = math.ceil(dataset.get_dataset_size() / loop_size)
        return sink_count

    def get_sink_size(self):
        """Get the sink size for the device."""
        sink_size = 1
        if hasattr(self.dataset, '__loop_size__'):
            sink_size = self.dataset.__loop_size__
        elif context.get_context("device_target") in ("Ascend", "GPU"):
            if self.sink_size > 0:
                sink_size = self.sink_size
            else:
                sink_size = self.dataset.get_dataset_size()
        return sink_size


class _DatasetIterPyNative(_DatasetIter):
    """Iter for context (mode=PYNATIVE_MODE)."""

    def __init__(self, dataset, sink_size, epoch_num):
        super().__init__(dataset, sink_size, epoch_num)
        if sink_size > 0:
            self.sink_count = sink_size
        else:
            self.sink_count = dataset.get_dataset_size()

        def op():
            return tuple()

        self.op = op


class _DatasetIterMSLoopSink(_DatasetIter):
    """Iter for context (device_target=Ascend or GPU)."""

    def __init__(self, dataset, sink_size, epoch_num):
        super().__init__(dataset, sink_size, epoch_num)
        self.sink_count = self.get_sink_count(dataset)
        # When self._parallel_mode is semi_auto_parallel or auto_parallel and full_batch is not used,
        # compile with a complete tensor and slice the tensor at runtime. The batch dimension of the
        # tensors used for compilation is device_number times that of the tensors used for running.
        # Currently only LoopSink is supported.
        if _need_to_full():
            device_num = _get_device_num() // _get_pipeline_stages()
            self.dataset_shapes = _to_full_shapes(
                self.dataset_shapes, device_num)

        def op():
            return tuple()

        self.op = op


class _DatasetIterPSServer(_DatasetIter):
    """Iter for context on MS_PSERVER or MS_SCHED."""

    def __init__(self, dataset, sink_size, epoch_num):
        super().__init__(dataset, sink_size, epoch_num)
        self.sink_count = 1
        self.sink_size = 1
        self.op = None

        def op():
            return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1)

        self.op = op


class _DatasetIterNormal:
    """Iter for normal (non-sink) mode, feeding the data from the host."""

    def __init__(self, dataset, epoch_num=-1):
        self.dataset = dataset
        self.device_num = _get_device_num()
        self.global_rank = _get_global_rank()
        self.iter = self.dataset.create_tuple_iterator(
            num_epochs=epoch_num, do_copy=True)

    def __iter__(self):
        return self

    def __next__(self):
        data = self.iter.__next__()
        return data


__all__ = ["DatasetHelper", "connect_network_with_dataset"]