• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Built-in iterators.
16"""
17from abc import abstractmethod
18import os
19import signal
20import weakref
21import numpy as np
22
23from mindspore.common.tensor import Tensor
24import mindspore._c_dataengine as cde
25
26from mindspore import log as logger
27
# Process-wide flag: set True once _cleanup() has started releasing iterators.
_ITERATOR_CLEANUP = False
29
30
def _set_iterator_cleanup():
    """Raise the module-level flag that iterator cleanup has begun."""
    global _ITERATOR_CLEANUP
    _ITERATOR_CLEANUP = True
34
35
def _unset_iterator_cleanup():
    """Clear the module-level flag so new iterators are considered live again."""
    global _ITERATOR_CLEANUP
    _ITERATOR_CLEANUP = False
39
40
def check_iterator_cleanup():
    """Return True if iterator cleanup has been initiated for this process."""
    # Reading a module-level name does not require a ``global`` declaration;
    # the original statement was a no-op.
    return _ITERATOR_CLEANUP
44
45
# Weak references to every live Iterator, so _cleanup() can release them
# without extending their lifetimes (entries are pruned in Iterator.stop()).
ITERATORS_LIST = list()
47
48
def _cleanup():
    """Release every Iterator still registered in ITERATORS_LIST."""
    _set_iterator_cleanup()
    # Walk the registry newest-first; dereference each weakref and release
    # only the iterators that are still alive.
    for weak_itr in reversed(ITERATORS_LIST):
        iterator = weak_itr()
        if iterator is not None:
            iterator.release()
56
57
class Iterator:
    """
    General Iterator over a dataset.

    Builds a C++ runtime pipeline from the dataset's IR tree and pulls rows
    from it; subclasses decide the row container (dict or list) via _get_next.

    Attributes:
        dataset: Dataset to be iterated over
    """

    def __init__(self, dataset, num_epochs=-1, output_numpy=False, do_copy=True):
        """
        Args:
            dataset: Dataset to iterate over; its IR tree is copied so the
                original object is left untouched.
            num_epochs (int): Number of epochs the consumer runs; -1 means
                iterate indefinitely.
            output_numpy (bool): If True, rows stay as numpy arrays.
            do_copy (bool): When producing Tensors, copy the data (True) or
                wrap the numpy buffer without copying (False).
        """
        self._col_names = None

        # Create a copy of the tree and work on it; keep the original dataset
        # so its dataset_size can be backfilled when iteration ends.
        self.__ori_dataset = dataset

        self.ir_tree, self.dataset = dataset.create_ir_tree()

        self._runtime_context = cde.PythonRuntimeContext()
        self._runtime_context.Init()
        consumer = cde.PythonIteratorConsumer(num_epochs)
        consumer.Init(self.ir_tree)
        self._runtime_context.AssignConsumer(consumer)
        self._iterator = self._runtime_context.GetConsumer()

        # Pick the per-tensor conversion once, instead of branching per row.
        self._transform_tensor = lambda t: t.as_array()
        if not output_numpy:
            if do_copy:
                self._transform_tensor = lambda t: Tensor(t.as_array())
            else:
                self._transform_tensor = lambda t: Tensor.from_numpy(t.as_array())
        self.__index = 0

        # Register a weak reference so _cleanup() can release this iterator
        # without keeping it alive.
        ITERATORS_LIST.append(weakref.ref(self))
        _unset_iterator_cleanup()

    def __iter__(self):
        return self

    def stop(self):
        """
        Manually terminate Python iterator instead of relying on out of scope destruction.

        Safe to call more than once: the hasattr guards make repeated calls
        no-ops. Also prunes dead weakrefs (and this iterator's own entry)
        from ITERATORS_LIST.
        """
        if hasattr(self, '_runtime_context') and self._runtime_context:
            if hasattr(self, '_iterator') and self._iterator:
                self._runtime_context.Terminate()
                del self._iterator
            del self._runtime_context
            del self.dataset

            # get weakref which is dead
            dead_iterator = []
            for index, item in enumerate(ITERATORS_LIST):
                # item() == None indicate the object is dead
                # id(item()) == id(self) indicate del self
                if item() is None or id(item()) == id(self):
                    dead_iterator.append(index)

            # del dead weakref; reverse order keeps remaining indices valid
            for index in reversed(dead_iterator):
                ITERATORS_LIST.pop(index)

    def release(self):
        """Alias of stop(), kept for API compatibility."""
        self.stop()

    def __del__(self):
        self.release()

    @abstractmethod
    def _get_next(self):
        raise RuntimeError("Calling base class Iterator's get_next is invalid.")

    def __next__(self):
        # stop() deletes the _runtime_context attribute outright, so a direct
        # `self._runtime_context` access would raise AttributeError here
        # instead of the intended RuntimeError; probe with getattr.
        if not getattr(self, '_runtime_context', None):
            logger.warning("Iterator does not have a running C++ pipeline. "
                           "It might be because Iterator stop() had been called, or C++ pipeline crashed silently.")
            raise RuntimeError("Iterator does not have a running C++ pipeline.")

        data = self._get_next()
        if not data:
            if self.__index == 0:
                logger.warning("No records available.")
            # Backfill the dataset size discovered during this full pass.
            if self.__ori_dataset.dataset_size is None:
                self.__ori_dataset.dataset_size = self.__index
            raise StopIteration
        self.__index += 1
        return data

    def __deepcopy__(self, memo):
        # Iterators hold a live C++ pipeline and cannot be meaningfully
        # copied; deepcopy returns the same object.
        return self

    def _getters(self):
        """
        Get pipeline information (currently the column names) by attaching a
        TreeGetters consumer to the runtime context.
        """
        getter = cde.TreeGetters()
        getter.Init(self.ir_tree)
        self._runtime_context.AssignConsumer(getter)
        self._col_names = getter.GetColumnNames()

    def get_col_names(self):
        """
        Get names of the columns in the dataset
        """
        if self._col_names is None:
            self._getters()
        return self._col_names
163
164
class DictIterator(Iterator):
    """
    The derived class of Iterator with dict type.
    """

    def _get_next(self):
        """
        Returns the next record in the dataset as dictionary

        Returns:
            Dict, the next record in the dataset.

        Raises:
            RuntimeError: propagated from the C++ pipeline; if it reports an
                out-of-memory condition the whole process is killed first.
        """
        try:
            return {k: self._transform_tensor(t) for k, t in self._iterator.GetNextAsMap().items()}
        except RuntimeError as err:
            # Possible "Out of memory" / "MemoryError" error: the pipeline
            # cannot recover from this, so terminate the process immediately.
            err_info = str(err)
            if err_info.find("Out of memory") >= 0 or err_info.find("MemoryError") >= 0:
                logger.error("Memory error occurred, process will exit.")
                os.kill(os.getpid(), signal.SIGKILL)
            # Bare `raise` re-raises the active exception with its original
            # traceback intact (unlike `raise err`, which appends a frame).
            raise
186
187
class TupleIterator(Iterator):
    """
    The derived class of Iterator with list type.
    """

    def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
        """Optionally project `columns` onto the dataset before iterating."""
        if columns is not None:
            # A single column name is accepted as well as a list of names.
            columns = columns if isinstance(columns, list) else [columns]
            dataset = dataset.project(columns)
        super().__init__(dataset, num_epochs, output_numpy, do_copy)

    def _get_next(self):
        """
        Returns the next record in the dataset as a list

        Returns:
            List, the next record in the dataset.
        """
        row = self._iterator.GetNextAsList()
        return [self._transform_tensor(item) for item in row]
209
210
class DummyIterator:
    """
    A DummyIterator only work when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED"
    """

    def __init__(self, dataset, mode):
        # Cache the output signature up front; rows are synthesized from it.
        self.mode = mode
        self.shapes = dataset.output_shapes()
        self.types = dataset.output_types()
        self.fetched_first = False

    def __get_tensor(self):
        """Build one all-zeros row matching the dataset's shapes and types."""
        return [Tensor(np.zeros(shape, dtype))
                for shape, dtype in zip(self.shapes, self.types)]

    def __iter__(self):
        return self

    def __next__(self):
        # In "tuple" mode exactly one synthetic row is served; every other
        # call (and every other mode) ends the iteration.
        if self.mode == "tuple" and not self.fetched_first:
            self.fetched_first = True
            return self.__get_tensor()
        raise StopIteration()
239