# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in iterators.
"""
from abc import abstractmethod
import os
import signal
import weakref
import numpy as np

from mindspore.common.tensor import Tensor
import mindspore._c_dataengine as cde

from mindspore import log as logger

# Module-wide flag: set once a global iterator cleanup has been triggered so
# later checks know the pipeline is (or was) shutting down.
_ITERATOR_CLEANUP = False


def _set_iterator_cleanup():
    """Mark that a global iterator cleanup has been triggered."""
    global _ITERATOR_CLEANUP
    _ITERATOR_CLEANUP = True


def _unset_iterator_cleanup():
    """Clear the cleanup flag (done whenever a new iterator is created)."""
    global _ITERATOR_CLEANUP
    _ITERATOR_CLEANUP = False


def check_iterator_cleanup():
    """Return True if a global iterator cleanup has been triggered."""
    # Read-only access: no `global` declaration needed.
    return _ITERATOR_CLEANUP


# Weak references to every live Iterator, so _cleanup() can release them all
# without itself keeping them alive.
ITERATORS_LIST = list()


def _cleanup():
    """Release all the Iterator."""
    _set_iterator_cleanup()
    # Release in reverse creation order so later-created iterators go first.
    for itr_ref in reversed(ITERATORS_LIST):
        itr = itr_ref()
        if itr is not None:
            itr.release()


class Iterator:
    """
    General Iterator over a dataset.

    Attributes:
        dataset: Dataset to be iterated over
    """

    def __init__(self, dataset, num_epochs=-1, output_numpy=False, do_copy=True):
        self._col_names = None

        # create a copy of tree and work on it.
        self.__ori_dataset = dataset

        self.ir_tree, self.dataset = dataset.create_ir_tree()

        self._runtime_context = cde.PythonRuntimeContext()
        self._runtime_context.Init()
        consumer = cde.PythonIteratorConsumer(num_epochs)
        consumer.Init(self.ir_tree)
        self._runtime_context.AssignConsumer(consumer)
        self._iterator = self._runtime_context.GetConsumer()

        # How C++ tensors are surfaced to the caller:
        #   output_numpy      -> raw numpy arrays
        #   do_copy           -> independent Tensor copies (safe default)
        #   not do_copy       -> Tensors sharing the numpy buffer (zero-copy)
        self._transform_tensor = lambda t: t.as_array()
        if not output_numpy:
            if do_copy:
                self._transform_tensor = lambda t: Tensor(t.as_array())
            else:
                self._transform_tensor = lambda t: Tensor.from_numpy(t.as_array())
        self.__index = 0

        # Register a weak reference so _cleanup() can release this iterator
        # without preventing normal garbage collection.
        ITERATORS_LIST.append(weakref.ref(self))
        _unset_iterator_cleanup()

    def __iter__(self):
        return self

    def stop(self):
        """
        Manually terminate Python iterator instead of relying on out of scope destruction.

        Safe to call more than once: the hasattr() guards make repeated calls
        (and calls from __del__ after a partial __init__) no-ops.
        """
        if hasattr(self, '_runtime_context') and self._runtime_context:
            if hasattr(self, '_iterator') and self._iterator:
                self._runtime_context.Terminate()
                del self._iterator
            del self._runtime_context
            del self.dataset

            # get weakref which is dead
            dead_iterator = []
            for index, item in enumerate(ITERATORS_LIST):
                # item() is None indicates the referent is already dead;
                # item() is self indicates we are deleting ourselves.
                if item() is None or item() is self:
                    dead_iterator.append(index)

            # del dead weakref (reversed so earlier indices stay valid)
            for index in reversed(dead_iterator):
                ITERATORS_LIST.pop(index)

    def release(self):
        self.stop()

    def __del__(self):
        self.release()

    @abstractmethod
    def _get_next(self):
        raise RuntimeError("Calling base class Iterator's get_next is invalid.")

    def __next__(self):
        # stop() deletes the _runtime_context attribute, so guard with
        # hasattr(): without it, iterating a stopped iterator raises
        # AttributeError instead of the intended RuntimeError below.
        if not hasattr(self, '_runtime_context') or not self._runtime_context:
            logger.warning("Iterator does not have a running C++ pipeline. "
                           "It might be because Iterator stop() had been called, or C++ pipeline crashed silently.")
            raise RuntimeError("Iterator does not have a running C++ pipeline.")

        data = self._get_next()
        if not data:
            if self.__index == 0:
                logger.warning("No records available.")
            # Cache the observed size on the original dataset if unknown.
            if self.__ori_dataset.dataset_size is None:
                self.__ori_dataset.dataset_size = self.__index
            raise StopIteration
        self.__index += 1
        return data

    def __deepcopy__(self, memo):
        # Iterators own C++ runtime state and cannot be meaningfully copied.
        return self

    def _getters(self):
        """
        Get pipeline information.
        """
        getter = cde.TreeGetters()
        getter.Init(self.ir_tree)
        self._runtime_context.AssignConsumer(getter)
        self._col_names = getter.GetColumnNames()

    def get_col_names(self):
        """
        Get names of the columns in the dataset
        """
        # Lazily fetched once and cached.
        if self._col_names is None:
            self._getters()
        return self._col_names


class DictIterator(Iterator):
    """
    The derived class of Iterator with dict type.
    """

    def _get_next(self):
        """
        Returns the next record in the dataset as dictionary

        Returns:
            Dict, the next record in the dataset.
        """
        try:
            return {k: self._transform_tensor(t) for k, t in self._iterator.GetNextAsMap().items()}
        except RuntimeError as err:
            # Likely an "Out of memory" / "MemoryError" condition: the process
            # cannot recover, so terminate it outright.
            err_info = str(err)
            if err_info.find("Out of memory") >= 0 or err_info.find("MemoryError") >= 0:
                logger.error("Memory error occurred, process will exit.")
                os.kill(os.getpid(), signal.SIGKILL)
            # Bare raise preserves the original traceback.
            raise


class TupleIterator(Iterator):
    """
    The derived class of Iterator with list type.
    """

    def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
        # Restrict output to the requested columns (in order) via a project op.
        if columns is not None:
            if not isinstance(columns, list):
                columns = [columns]
            dataset = dataset.project(columns)
        super().__init__(dataset, num_epochs, output_numpy, do_copy)

    def _get_next(self):
        """
        Returns the next record in the dataset as a list

        Returns:
            List, the next record in the dataset.
        """

        return [self._transform_tensor(t) for t in self._iterator.GetNextAsList()]


class DummyIterator:
    """
    A DummyIterator only work when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED"
    """

    def __init__(self, dataset, mode):
        self.mode = mode
        self.shapes = dataset.output_shapes()
        self.types = dataset.output_types()
        self.fetched_first = False

    def __get_tensor(self):
        """Build one row of all-zero tensors matching the dataset's schema."""
        tensor_row = []
        for np_shape, np_type in zip(self.shapes, self.types):
            input_np = np.zeros(np_shape, np_type)
            tensor = Tensor(input_np)
            tensor_row.append(tensor)
        return tensor_row

    def __iter__(self):
        return self

    def __next__(self):
        # Yields exactly one zero-filled row in "tuple" mode, then stops;
        # any other mode stops immediately.
        if self.mode == "tuple":
            if not self.fetched_first:
                self.fetched_first = True
                return self.__get_tensor()
        raise StopIteration()