# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This dataset module creates an internal queue class to more optimally pass data
between multiple processes in Python. It has the same API as multiprocessing.Queue
but it will pass large data through shared memory.
"""

import errno
import multiprocessing
import platform
import queue
import types

import numpy as np

from mindspore import log as logger
import mindspore._c_dataengine as cde
from ..transforms.py_transforms_util import ExceptionHandler


class _SharedQueue(multiprocessing.queues.Queue):
    """
    Class to implement a queue using shared memory for better performance.

    Large numpy arrays are written into shared memory and only a small metadata
    tuple travels through the underlying pipe; small objects are pickled
    directly through the queue.

    Args:
        size: Number of elements in the queue.
        count: Shared variable to suppress log printing.
        copy_out: Flag to indicate whether an extra copy should be done before returning. If data will immediately be
            copied before returning, then this can be set to False.
        max_rowsize: Maximum size of any element in the Queue in MB. On non-Windows platforms a value
            of -1 enables dynamically sized per-row shared memory instead of pre-allocated segments.
    """

    def __init__(self, size, count, copy_out=False, max_rowsize=6):
        super().__init__(size, ctx=multiprocessing.get_context())

        self.copy_out = copy_out

        # pipe can hold up to 65,536 bytes at a time
        # there is less benefit for small data. For small data it can be slower as we need to pass
        # 100 bytes of metadata and then access the shared memory.
        self.min_shared_mem = 10000
        self.data_immediate = 0  # metadata tag: payload is embedded in the message itself
        self.data_shared = 1     # metadata tag: payload lives in shared memory
        self.count = count
        self.print_error = True

        if platform.system().lower() != 'windows' and max_rowsize == -1:
            # Allocate shared memory on demand, sized exactly for each row.
            self.dynamic_shm = True
        else:
            self.dynamic_shm = False
            # change max_rowsize in MB into bytes
            self.seg_size = max_rowsize * 1024 * 1024
            self.shm_list = []
            self.seg_pos = 0
            # num_seg has to be 2 more than the queue size. We can have remote worker filling a buffer, main process
            # reading a buffer and also have a full queue of buffers in the meta-data queue
            self.num_seg = size + 2
            for _ in range(self.num_seg):
                try:
                    a = multiprocessing.Array("b", self.seg_size)
                except OSError as e:
                    if e.errno == errno.ENOMEM:
                        raise RuntimeError("Failed to allocate shared memory for {0} elements of {1}MB: {2}"
                                           .format(self.num_seg, self.seg_size / 1024 / 1024, e)) from e
                    raise
                else:
                    self.shm_list.append(a)

    def put_until(self, data, timeout=None, exit_signal=None):
        """Put data into the queue. Block until timeout is reached or exit_signal is set."""
        while True:
            try:
                self.put(data, timeout=timeout)
                return
            except queue.Full as e:
                if exit_signal is None:
                    raise e
                if exit_signal.is_set():
                    return
                continue

    def put(self, data, timeout=None):
        """Serialize one row into the metadata queue, routing large numpy arrays through shared memory.

        Raises:
            TypeError: If an element of ``data`` is a generator (not picklable).
            queue.Full: If the metadata queue stays full past ``timeout``.
        """
        if isinstance(data, ExceptionHandler):  # pylint: disable=too-many-nested-blocks
            super().put(data, timeout=timeout)
        else:
            name_list = []
            start_bytes = 0
            if not isinstance(data, tuple):
                data = (data,)
            # NOTE(review): the original code had an `isinstance(data, np.ndarray)` branch here,
            # but `data` is always a tuple at this point, so that branch was unreachable; removed.
            for r in data:
                # the map:pyfunc is a yield generator which can't be serialized
                if isinstance(r, types.GeneratorType):
                    raise TypeError("Cannot pickle {} object, please verify pyfunc return with numpy array"
                                    .format(type(r)))
                if isinstance(r, np.ndarray) and self.dynamic_shm:
                    # Allocate a dedicated shared-memory block sized for this array and pass its
                    # duplicated file descriptor to the receiving process.
                    byte = r.nbytes
                    shm = cde.SharedMemory(None, True, -1, byte)
                    dest = np.ndarray(r.shape, r.dtype, buffer=shm.buf())
                    np.copyto(dest, r)
                    fd = shm.fd()
                    df = multiprocessing.reduction.DupFd(fd)
                    name_list.append((self.data_shared, r.dtype, r.shape, shm.name(), df, shm.size()))
                elif (isinstance(r, np.ndarray) and r.size > self.min_shared_mem
                      and start_bytes + r.nbytes < self.seg_size):
                    # need to convert start_bytes to offset in array
                    start_offset = start_bytes
                    dest = np.ndarray(r.shape, r.dtype, buffer=self.shm_list[self.seg_pos].get_obj(),
                                      offset=start_offset)
                    np.copyto(dest, r)
                    byte = r.nbytes
                    # round up so every array stays 8-byte aligned within the segment
                    byte = 8 * ((byte + 7) // 8)
                    start_bytes += byte
                    name_list.append((self.data_shared, self.seg_pos, byte, r.dtype, r.shape))
                else:
                    if isinstance(r, np.ndarray) and r.size > self.min_shared_mem:
                        # Only print out error the first time it happens
                        if self.count.value == 0 and self.print_error:
                            logger.warning(
                                "Using shared memory queue, but rowsize is larger than allocated memory "
                                + "max_rowsize: "
                                + str(self.seg_size / 1024 / 1024)
                                + "MB, current rowsize: "
                                + str((start_bytes + r.nbytes) / 1024 / 1024)
                                + "MB."
                            )
                            self.print_error = False
                            self.count.value += 1
                    name_list.append((self.data_immediate, r))
            super().put(name_list, timeout=timeout)
            # note above could generate a queue full exception. It will be handled by the caller
            # only increment seg_pos after successfully adding to metadata queue

            if start_bytes > 0:
                self.seg_pos = (self.seg_pos + 1) % self.num_seg

    def get_until(self, timeout=None, exit_signal=None):
        """Get data from the queue. Block until timeout is reached or exit_signal is set."""
        while True:
            try:
                r = self.get(timeout=timeout)
            except queue.Empty as e:
                if exit_signal is None:
                    raise e
                if exit_signal.is_set():
                    return None
                continue
            if r is None:
                # receive finish signal
                return None
            # BUGFIX: guard against exit_signal being None before calling is_set()
            if exit_signal is not None and exit_signal.is_set():
                # loop until the queue becomes empty
                continue
            return r

    def get(self, timeout=None):
        """Reassemble one row from the metadata queue, reading shared-memory entries back into arrays.

        Raises:
            queue.Empty: If no row arrives within ``timeout``.
            RuntimeError: If a metadata entry carries an unknown tag.
        """
        result = super().get(timeout=timeout)
        if isinstance(result, ExceptionHandler):
            return result
        r = []
        start_bytes = 0
        for x in result:
            if x[0] == self.data_shared:
                if self.dynamic_shm:
                    dtype, shape, shm_name, df, buf_size = x[1:]
                    fd = df.detach()
                    shm = cde.SharedMemory(shm_name, False, fd, buf_size)
                    data = np.ndarray(shape, dtype, buffer=shm.buf())
                    # always copy: the per-row shared memory block is released after this call
                    dest = np.copy(data)
                    r.append(dest)
                else:
                    seg_pos, byte, dtype, shape = x[1:]
                    start_offset = start_bytes
                    b = self.shm_list[seg_pos]
                    data = np.ndarray(shape, dtype, buffer=b.get_obj(), offset=start_offset)
                    start_bytes += byte
                    if self.copy_out:
                        dest = np.copy(data)
                        r.append(dest)
                    else:
                        # zero-copy view into the segment; caller must consume it before the
                        # segment is reused
                        r.append(data)
            elif x[0] == self.data_immediate:
                r.append(x[1])
            else:
                raise RuntimeError("SharedQueue, invalid entry in metadata.")
        return tuple(r)

    def __del__(self):
        if not self.dynamic_shm:
            # Drop references to the pre-allocated segments so the shared memory can be reclaimed.
            shm_list_len = len(self.shm_list)
            for idx in range(shm_list_len):
                del self.shm_list[shm_list_len - idx - 1]
            self.shm_list.clear()
            del self.shm_list

        self.close()
        self.join_thread()


class _Queue(multiprocessing.queues.Queue):
    """Specialized multiprocessing Queue that supports interrupted operations."""

    def __init__(self, size):
        super().__init__(size, ctx=multiprocessing.get_context())

    def put_until(self, data, timeout=None, exit_signal=None):
        """Put data into the queue. Block until timeout is reached or exit_signal is set."""
        while True:
            try:
                self.put(data, timeout=timeout)
                return
            except queue.Full as e:
                if exit_signal is None:
                    raise e
                if exit_signal.is_set():
                    return
                continue

    def get_until(self, timeout=None, exit_signal=None):
        """Get data from the queue. Block until timeout is reached or exit_signal is set."""
        while True:
            try:
                r = self.get(timeout=timeout)
            except queue.Empty as e:
                if exit_signal is None:
                    raise e
                if exit_signal.is_set():
                    return None
                continue
            if r is None:
                # receive finish signal
                return None
            # BUGFIX: guard against exit_signal being None before calling is_set()
            if exit_signal is not None and exit_signal.is_set():
                # loop until the queue becomes empty
                continue
            return r