# Copyright 2019-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Built-in iterators"""
from abc import abstractmethod
from copy import deepcopy
import json
import os
import signal
import weakref
import numpy as np

import mindspore._c_dataengine as cde
from mindspore.common.tensor import Tensor, np_types
import mindspore.dataset.engine.offload as offload
from mindspore.dataset.core.config import get_debug_mode

from mindspore import log as logger

_ITERATOR_CLEANUP = False


def _set_iterator_cleanup():
    global _ITERATOR_CLEANUP
    _ITERATOR_CLEANUP = True


def _unset_iterator_cleanup():
    global _ITERATOR_CLEANUP
    _ITERATOR_CLEANUP = False


def check_iterator_cleanup():
    global _ITERATOR_CLEANUP
    return _ITERATOR_CLEANUP


ITERATORS_LIST = list()


def _cleanup():
    """Release all the iterators."""
    _set_iterator_cleanup()
    for itr_ref in reversed(ITERATORS_LIST):
        itr = itr_ref()
        if itr is not None:
            itr.release()
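
# A sketch of the intended wiring (an assumption for illustration, not shown in
# this file): registering the hook at interpreter shutdown, e.g. with
# ``atexit.register(_cleanup)``, would release every iterator still alive in
# ITERATORS_LIST through its weak reference.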


class Iterator:
    """
    General Iterator over a dataset.

    Attributes:
        dataset: Dataset to be iterated over.
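
    Examples:
        A minimal sketch of typical usage. Iterators are normally obtained from a
        dataset's factory methods rather than constructed directly; the source
        dataset below is only an illustrative assumption:

        >>> import numpy as np
        >>> import mindspore.dataset as ds
        >>> dataset = ds.NumpySlicesDataset(np.arange(6).reshape(3, 2), column_names=["data"])
        >>> for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        ...     print(row["data"].shape)
        (2,)
        (2,)
        (2,)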
67    """
68
69    def __init__(self, dataset, num_epochs=-1, output_numpy=False, do_copy=True):
70        self._col_names = None
71
72        # create a copy of tree and work on it.
73        self.__ori_dataset = dataset
74
75        self.ir_tree, self.dataset = dataset.create_ir_tree()
76
77        self._runtime_context = cde.PythonRuntimeContext()
78        self._runtime_context.Init()
79        if dataset.get_init_step() == 0:
80            init_step = 0
81            dataset_size = -1
82        else:
83            init_step = dataset.get_init_step()
84            dataset_size = dataset.get_dataset_size()
85        if get_debug_mode():
86            consumer = cde.PythonPullBasedIteratorConsumer(num_epochs)
87            consumer.Init(self.ir_tree)
88        else:
89            consumer = cde.PythonIteratorConsumer(num_epochs)
90            consumer.Init(self.ir_tree, init_step, dataset_size)
91        self._runtime_context.AssignConsumer(consumer)
92        self._iterator = self._runtime_context.GetConsumer()
93        self._output_numpy = output_numpy
94        self._do_copy = do_copy
95
96        self.__index = 0
97
98        self.offload_model = None
99        json_offload = json.loads(consumer.GetOffload())
100
101        # See if GetOffload identified any operations set to be offloaded.
102        if json_offload is not None:
103            offload.check_concat_zip_dataset(self.__ori_dataset)
104            self.offload_model = offload.GetOffloadModel(consumer, self.__ori_dataset.get_col_names())
105
106        ITERATORS_LIST.append(weakref.ref(self))
107        _unset_iterator_cleanup()
108
109    def __iter__(self):
110        return self

    def stop(self):
        """
        Manually terminate the Python iterator instead of relying on out-of-scope destruction.
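
        Examples:
            A minimal sketch (assumes ``dataset`` is an existing dataset object):

            >>> it = dataset.create_dict_iterator(num_epochs=1)
            >>> first = next(it)
            >>> it.stop()  # release the underlying C++ pipeline immediately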
        """
        if hasattr(self, '_runtime_context') and self._runtime_context:
            if hasattr(self, '_iterator') and self._iterator:
                self._runtime_context.Terminate()
                del self._iterator
            del self._runtime_context
            del self.dataset

            # collect the weakrefs that are dead
            dead_iterator = []
            for index, item in enumerate(ITERATORS_LIST):
                # item() is None indicates the referent is dead
                # id(item()) == id(self) indicates self is being deleted
                if item() is None or id(item()) == id(self):
                    dead_iterator.append(index)

            # delete the dead weakrefs
            for index in reversed(dead_iterator):
                ITERATORS_LIST.pop(index)

    def release(self):
        self.stop()

    def __del__(self):
        self.release()

    @abstractmethod
    def _get_next(self):
        raise RuntimeError("Calling base class Iterator's get_next is invalid.")

    def __next__(self):
        if not self._runtime_context:
            logger.warning("Iterator does not have a running C++ pipeline. "
                           "It might be because Iterator stop() has been called, or the C++ pipeline crashed silently.")
            raise RuntimeError("Iterator does not have a running C++ pipeline.")

        # Note: offload is applied inside _get_next() if applicable, since _get_next() converts to the output format.
        data = self._get_next()
        if not data:
            if self.__index == 0:
                logger.warning("No records available.")
            if self.__ori_dataset.dataset_size is None:
                self.__ori_dataset.dataset_size = self.__index
            raise StopIteration
        self.__index += 1

        return data

    def __deepcopy__(self, memo):
        return self

    def _getters(self):
        """
        Get pipeline information.
        """
        getter = cde.TreeGetters()
        getter.Init(self.ir_tree)
        self._runtime_context.AssignConsumer(getter)
        self._col_names = getter.GetColumnNames()

    def get_col_names(self):
        """
        Get the names of the columns in the dataset.
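
        Examples:
            A minimal sketch (assumes ``dataset`` is an existing dataset object
            with a single "data" column):

            >>> it = dataset.create_tuple_iterator(num_epochs=1)
            >>> it.get_col_names()
            ['data']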
        """
        if self._col_names is None:
            self._col_names = self.__ori_dataset.get_col_names()
        return self._col_names

    def _reset(self, step, dataset_size):
        """
        Reset the iterator to continue from the given global step.

        Args:
            step (int): Global step number.
            dataset_size (int): The number of steps in one epoch.
        """
        self._iterator.Reset(step, dataset_size)

    def __convert_python(self, obj, to_numpy):
        """
        Attempts to recursively convert a Python object to NumPy array(s) or Tensor(s).

        Args:
            obj (any): The Python object to be converted.
            to_numpy (bool): If True, convert primitive types to NumPy arrays;
                             if False, convert them to Tensor. Objects of unsupported
                             types are returned as is (deep-copied when do_copy is True).
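
        A sketch of the expected behavior (illustrative values, assuming
        to_numpy=True and do_copy=True):

        >>> # {"a": 1, "b": (2.0, "x")} is converted element-wise to
        >>> # {"a": np.array(1), "b": (np.array(2.0), np.array("x"))}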
        """
        if isinstance(obj, (int, float, bool, str, np.ndarray, np.str_, np.bytes_, *np_types)):
            # error out if array is of unsupported type
            if isinstance(obj, np.ndarray) and obj.dtype not in np_types and obj.dtype.kind not in ('U', 'S'):
                new_line = '\n'
                raise TypeError("A NumPy array of unsupported type detected: {}."
                                "\nSupported types are: {}.".format(
                                    obj.dtype, new_line.join(map(str, (*np_types, np.str_, np.bytes_)))))
            if to_numpy:
                return np.array(obj, copy=self._do_copy)
            if self._do_copy:
                return Tensor(np.asarray(obj))
            return Tensor.from_numpy(np.asarray(obj))
        if isinstance(obj, dict):
            return {key: self.__convert_python(val, to_numpy) for key, val in obj.items()}
        if isinstance(obj, tuple):
            return tuple(self.__convert_python(item, to_numpy) for item in obj)
        if isinstance(obj, list):
            return [self.__convert_python(item, to_numpy) for item in obj]
        # if we can't convert it to Tensor, return the object as is
        if self._do_copy:
            return deepcopy(obj)
        return obj

    def _transform_md_to_output(self, t):
        """Convert a C++ pipeline tensor to the configured output format (NumPy or Tensor)."""
        if self._output_numpy:
            if t.type().is_python():
                return self.__convert_python(t.as_python(), True)
            return t.as_array()
        return self._transform_md_to_tensor(t)

    def _transform_md_to_tensor(self, t):
        """Convert a C++ pipeline tensor to a mindspore Tensor."""
        if t.type().is_python():
            return self.__convert_python(t.as_python(), False)
        array = t.as_array()
        if self._do_copy:
            return Tensor(array)
        return Tensor.from_numpy(array)

    def _transform_tensor_to_output(self, t):
        """Convert a mindspore Tensor to the configured output format."""
        if self._output_numpy:
            return t.asnumpy()
        return t


class DictIterator(Iterator):
    """
    The derived class of Iterator that returns each record as a dictionary.
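
    Examples:
        A minimal sketch; this iterator is normally created via
        Dataset.create_dict_iterator (``dataset`` is assumed to exist):

        >>> for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        ...     print(sorted(row.keys()))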
    """

    def _get_next(self):
        """
        Returns the next record in the dataset as a dictionary.

        Returns:
            Dict, the next record in the dataset.
        """
        try:
            if self.offload_model is None:
                return {k: self._transform_md_to_output(t) for k, t in self._iterator.GetNextAsMap().items()}
            data = [self._transform_md_to_tensor(t) for t in self._iterator.GetNextAsList()]
            if data:
                data = offload.apply_offload_iterators(data, self.offload_model)
                # Create output dictionary after offload
                out_data = {}
                for i, col in enumerate(self.get_col_names()):
                    out_data[col] = self._transform_tensor_to_output(data[i])
                data = out_data
            return data

        except RuntimeError as err:
            # the pipeline may have run out of memory ("Out of memory" / "MemoryError")
            err_info = str(err)
            if err_info.find("Out of memory") >= 0 or err_info.find("MemoryError") >= 0:
                logger.critical("Memory error occurred, process will exit.")
                os.kill(os.getpid(), signal.SIGKILL)
            raise err


class TupleIterator(Iterator):
    """
    The derived class of Iterator that returns each record as a list.
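
    Examples:
        A minimal sketch; this iterator is normally created via
        Dataset.create_tuple_iterator (``dataset`` is assumed to exist):

        >>> for row in dataset.create_tuple_iterator(columns=["data"], num_epochs=1):
        ...     print(len(row))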
    """

    def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
        if columns is not None:
            if not isinstance(columns, list):
                columns = [columns]
            dataset = dataset.project(columns)
        super().__init__(dataset, num_epochs, output_numpy, do_copy)

    def _get_next(self):
        """
        Returns the next record in the dataset as a list.

        Returns:
            List, the next record in the dataset.
        """
        if self.offload_model is None:
            return [self._transform_md_to_output(t) for t in self._iterator.GetNextAsList()]
        data = [self._transform_md_to_tensor(t) for t in self._iterator.GetNextAsList()]
        if data:
            data = offload.apply_offload_iterators(data, self.offload_model)
        return [self._transform_tensor_to_output(t) for t in data]


class DummyIterator:
    """
    A DummyIterator only works when the environment variable MS_ROLE is set to "MS_PSERVER" or "MS_SCHED".
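
    Examples:
        A sketch of the behavior (``dataset`` is assumed to exist): the iterator
        yields a single zero-filled row matching the dataset's shapes and types,
        then raises StopIteration:

        >>> it = DummyIterator(dataset, mode="tuple")
        >>> row = next(it)   # list of zero-valued tensors
        >>> next(it)         # raises StopIteration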
    """

    def __init__(self, dataset, mode, output_numpy=False):
        self.mode = mode
        self.shapes = dataset.output_shapes()
        self.types = dataset.output_types()
        self.col_names = dataset.get_col_names()
        self.fetched_first = False
        self.output_numpy = output_numpy

    def __get_tensor(self):
        """Get the next row of zero-valued tensors."""
        tensor_row = []
        for np_shape, np_type in zip(self.shapes, self.types):
            input_np = np.zeros(np_shape, np_type)
            tensor = Tensor(input_np)
            if self.output_numpy:
                tensor_row.append(tensor.asnumpy())
            else:
                tensor_row.append(tensor)
        if self.mode == "dict":
            tensor_row = {col_name: tensor for col_name, tensor in zip(self.col_names, tensor_row)}
        return tensor_row

    def __iter__(self):
        return self

    def __next__(self):
        if not self.fetched_first:
            self.fetched_first = True
            return self.__get_tensor()
        raise StopIteration()